训练数据的 feed 机制
在 快速入门 中,我们看到了如何在 TaaS 上分布式执行 tf.contrib.learn 实现的 Boston House 样例模型。但是,在那个实现中,存在两个性能优化点:
- 将训练数据集和测试数据集全部读入内存,如果数据集比较大的话,会占用比较多的内存空间。
input_fn
函数直接将训练数据集或者测试数据集以tf.constants
的方式提供,这样会导致 TensorFlow 的 Graph 比较大,加载比较慢,占用内存也会比较大。
在本文中,我们将针对第二个问题来说明一下如何在 tf.contrib.learn 中使用 feed 机制。
在 Experiment 中,我们看到创建 Experiment
对象时可以传递两个参数:train_monitors
和 eval_hooks
。这两个参数都可以传递一个 tf.train.SessionRunHook
列表,分别用于模型训练和模型评估时分别传递给 estimator.fit()
和 estimator.evaluate()
。
tf.train.FeedFnHook 提供了在执行模型训练或者模型评估是向 tf.Session
的 run
方法提供 feed_dict
参数的机制。 创建 tf.train.FeedFnHook
对象时提供一个 feed_fn
参数,该参数会在每轮模型训练或者评估时被调用一次,然后将其返回值传递给 tf.Session
的 run
方法的 feed_dict
参数。
下面我们通过修改 快速入门 中 Boston House 样例任务代码,通过 feed 机制来在训练或者模型评估时提供数据。
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import pandas as pd
import tensorflow as tf
import tensorflow.contrib.learn as learn
from caicloud.clever.tensorflow import dist_base
tf.app.flags.DEFINE_string("data_dir",
".",
"data directory path.")
FLAGS = tf.app.flags.FLAGS
tf.logging.set_verbosity(tf.logging.INFO)
COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age",
"dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm",
"age", "dis", "tax", "ptratio"]
LABEL = "medv"
# 加载训练数据集和测试数据集
training_set = pd.read_csv("{0}/boston_train.csv".format(FLAGS.data_dir),
skipinitialspace=True,
skiprows=1, names=COLUMNS)
test_set = pd.read_csv("{0}/boston_test.csv".format(FLAGS.data_dir),
skipinitialspace=True,
skiprows=1, names=COLUMNS)
feature_cols = [tf.contrib.layers.real_valued_column(k)
for k in FEATURES]
run_config = tf.contrib.learn.RunConfig(
save_checkpoints_secs=dist_base.cfg.save_checkpoints_secs)
regressor = tf.contrib.learn.DNNRegressor(
feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir=dist_base.cfg.logdir,
config=run_config)
_input_tensors = None
_output_tensor = None
def input_fn():
""" 生成模型输入和预期输出的函数。
estimator 执行模型训练(fit)或者模型评估(evaluate)时调用该函数来生成模型的输入和
预期输出。在这里我们使用 tf.placeholder 来提供输入和输出,在真正执行模型训练或者模型
评估的时候,再通过 feed 的机制将训练数据或者测试数据 feed 到相对应的 placeholder 中。
"""
global _input_tensors, _output_tensor
_input_tensors = {k: tf.placeholder(dtype=tf.float64, shape=[None], name=k)
for k in FEATURES}
_output_tensor = tf.placeholder(dtype=tf.float64, shape=[None], name=LABEL)
return _input_tensors, _output_tensor
def feed_fn(data_set):
""" feed 函数。
estimator 在执行模型训练或者模型评估的时候,通过该函数获取到相对应数据集的 feed_dict,
然后传递给 tf.Session 的 run 方法。
"""
global _input_tensors, _output_tensor
feed_dict = {_input_tensors[k]: data_set[k].values
for k in FEATURES}
feed_dict[_output_tensor] = data_set[LABEL].values
return feed_dict
# tf.train.FeedFnHook 提供了在执行模型训练或者模型评估是向 tf.Session 的 run 方法提供
# feed_dict 参数的机制。 创建 tf.train.FeedFnHook 对象时提供一个 feed_fn 参数,该参数会
# 在每轮模型训练或者评估时被调用一次,然后将其返回值传递给 tf.Session 的 run 方法。
train_monitors = [tf.train.FeedFnHook(lambda: feed_fn(training_set))]
eval_hooks = [tf.train.FeedFnHook(lambda: feed_fn(test_set))]
exp = dist_base.Experiment(
estimator = regressor,
train_input_fn = input_fn,
eval_input_fn = input_fn,
train_monitors = train_monitors,
eval_hooks = eval_hooks,
eval_steps = 1)
exp.run()