tf.contrib.learn 模型导出

TaaS 平台支持将训练得到的模型进行托管,TaaS 平台会加载模型信息,然后基于该模型来对外提供 gRPC 和 RESTful API。

TaaS 平台提供了 ModelExportSpec 来提供模型导出配置。在导出 tf.contrib.learn 模型的时候,我们需要提供导出路径(export_dir)和特征向量(features)。

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import pandas as pd
import tensorflow as tf
import tensorflow.contrib.learn as learn

from caicloud.clever.tensorflow import dist_base
from caicloud.clever.tensorflow import model_exporter

tf.app.flags.DEFINE_string("data_dir",
                           ".",
                           "data directory path.")
tf.app.flags.DEFINE_string("export_dir",
                           None,
                           "model export directory path.")
FLAGS = tf.app.flags.FLAGS

tf.logging.set_verbosity(tf.logging.INFO)

COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age",
           "dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm",
            "age", "dis", "tax", "ptratio"]
LABEL = "medv"

# 加载训练数据集和测试数据集
training_set = pd.read_csv("{0}/boston_train.csv".format(FLAGS.data_dir),
                           skipinitialspace=True,
                           skiprows=1, names=COLUMNS)
test_set = pd.read_csv("{0}/boston_test.csv".format(FLAGS.data_dir),
                       skipinitialspace=True,
                       skiprows=1, names=COLUMNS)

# 输入数据的特征列表
feature_cols = [tf.contrib.layers.real_valued_column(k)
                for k in FEATURES]

# 定义深度神经网络回归模型
run_config = tf.contrib.learn.RunConfig(
    save_checkpoints_secs=dist_base.cfg.save_checkpoints_secs)
regressor = tf.contrib.learn.DNNRegressor(
    feature_columns=feature_cols,
    hidden_units=[10, 10],
    model_dir=dist_base.cfg.logdir,
    config=run_config)

def input_fn(data_set):
    feature_cols = {k: tf.constant(data_set[k].values)
                    for k in FEATURES}
    labels = tf.constant(data_set[LABEL].values)
    return feature_cols, labels

# 如果指定了 export_dir 命令行参数,则我们通过调用一个 ModelExportSpec 来指定相关的模型
# 导出配置。对于 tf.contrib.learn 模型训练任务,我们只需要指定导出路径(export_dir)和特征列表(features)。
model_export_spec = None
if FLAGS.export_dir is not None:
    model_export_spec = model_exporter.ModelExportSpec(
        export_dir = FLAGS.export_dir,
        features = feature_cols)

# 定义 dist_base.Experiment 对象来分布式执行 tf.contrib.learn 模型训练任务。
# 指定相对应的 esimator 对象,将要传递给 estimator.fit() 方法进行模型训练的
# input_fn 参数传递给 dist_base.Experiment 的 train_input_fn 参数。将执行模型
# 评估时传递给 estimator.evaluate() 方法的 input_fn 参数传递给 dist_base.Experiment
# 对象的 eval_input_fn 参数。
exp = dist_base.Experiment(
    estimator = regressor,
    train_input_fn = lambda: input_fn(training_set),
    eval_input_fn = lambda: input_fn(test_set),
    eval_steps = 1,
    model_export_spec = model_export_spec)

# 调用 dist.Experiment.run() 方法来分布式执行 tf.contrib.learn 模型训练任务。
exp.run()

results matching ""

    No results matching ""