tf.contrib.learn 模型导出
TaaS 平台支持将训练得到的模型进行托管,TaaS 平台会加载模型信息,然后基于该模型来对外提供 gRPC 和 RESTful API。
TaaS 平台提供了 ModelExportSpec 来提供模型导出配置。在导出 tf.contrib.learn 模型的时候,我们需要提供导出路径(export_dir
)和特征向量(features
)。
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import pandas as pd
import tensorflow as tf
import tensorflow.contrib.learn as learn
from caicloud.clever.tensorflow import dist_base
from caicloud.clever.tensorflow import model_exporter
tf.app.flags.DEFINE_string("data_dir",
".",
"data directory path.")
tf.app.flags.DEFINE_string("export_dir",
None,
"model export directory path.")
FLAGS = tf.app.flags.FLAGS
tf.logging.set_verbosity(tf.logging.INFO)
COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age",
"dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm",
"age", "dis", "tax", "ptratio"]
LABEL = "medv"
# 加载训练数据集和测试数据集
training_set = pd.read_csv("{0}/boston_train.csv".format(FLAGS.data_dir),
skipinitialspace=True,
skiprows=1, names=COLUMNS)
test_set = pd.read_csv("{0}/boston_test.csv".format(FLAGS.data_dir),
skipinitialspace=True,
skiprows=1, names=COLUMNS)
# 输入数据的特征列表
feature_cols = [tf.contrib.layers.real_valued_column(k)
for k in FEATURES]
# 定义深度神经网络回归模型
run_config = tf.contrib.learn.RunConfig(
save_checkpoints_secs=dist_base.cfg.save_checkpoints_secs)
regressor = tf.contrib.learn.DNNRegressor(
feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir=dist_base.cfg.logdir,
config=run_config)
def input_fn(data_set):
feature_cols = {k: tf.constant(data_set[k].values)
for k in FEATURES}
labels = tf.constant(data_set[LABEL].values)
return feature_cols, labels
# 如果指定了 export_dir 命令行参数,则我们通过调用一个 ModelExportSpec 来指定相关的模型
# 导出配置。对于 tf.contrib.learn 模型训练任务,我们只需要指定导出路径(export_dir)和特征列表(features)。
model_export_spec = None
if FLAGS.export_dir is not None:
model_export_spec = model_exporter.ModelExportSpec(
export_dir = FLAGS.export_dir,
features = feature_cols)
# 定义 dist_base.Experiment 对象来分布式执行 tf.contrib.learn 模型训练任务。
# 指定相对应的 esimator 对象,将要传递给 estimator.fit() 方法进行模型训练的
# input_fn 参数传递给 dist_base.Experiment 的 train_input_fn 参数。将执行模型
# 评估时传递给 estimator.evaluate() 方法的 input_fn 参数传递给 dist_base.Experiment
# 对象的 eval_input_fn 参数。
exp = dist_base.Experiment(
estimator = regressor,
train_input_fn = lambda: input_fn(training_set),
eval_input_fn = lambda: input_fn(test_set),
eval_steps = 1,
model_export_spec = model_export_spec)
# 调用 dist.Experiment.run() 方法来分布式执行 tf.contrib.learn 模型训练任务。
exp.run()