模型导出
本节是在手写数字识别 MNIST 样例基础上进行实现的。
最终训练出来的模型,要能够对新数据进行预测功能,则需要将训练得到的模型导出来,并通过 Serving 的方式提供预测功能。
TaaS 平台支持在 model_fn
方法返回 ModelFnHandler
对象中设置模型导出的相关配置 ModelExportSpec
,例如模型的输入和输出 Tensors、以及相关辅助的资产文件和初始化操作。ModelExportSpec
定义在 caicloud.clever.tensorflow.model_exporter
模块中,所以在使用之前要先 import 该模块。
from caicloud.clever.tensorflow import model_exporter
TaaS 平台在模型训练结束后,会将训练得到的模型导出。导出的模型可以通过 TaaS 平台启动一个 Serving 服务,该 Serving 服务会加载导出的模型并提供 gRPC 和 RESTful API(Serving 服务目前只在私有云版本中支持,公有云版本后续也会持续支持的)。
我们修改手写数字识别 MNIST 样例中提供的 model_fn
方法:
def model_fn(sync, num_replicas):
# 模型定义操作
...
# 定义模型导出配置
model_export_spec = model_exporter.ModelExportSpec(
export_dir=FLAGS.export_dir,
input_tensors={"image": _input_images},
output_tensors={"logits": logits})
return dist_base.ModelFnHandler(
global_step=_global_step,
optimizer=optimizer,
model_export_spec=model_export_spec)
完整代码请参考 caicloud/tensorflow-tutorial。