Experiement 类
支持的 caicloud.tensorflow 最低版本:2.1.0
所在模块:caicloud.clever.tensorflow.dist_base
。
保存 tf.contrib.learn 的模型训练的所有信息。
要支持 tf.contrib.learn 的模型训练任务在 TaaS 上分布式执行,需要创建一个 Experiment
对对象,然后调用其 run()
方法。基本代码框架如下:
import tensorflow as tf
import tensorflow.contrib.learn as learn
from caicloud.clever.tensorflow import dist_base
estimator = learn.Estimator(model_fn=...)
# input_fn for estimator.fit() or estimator.evaluate()
def input_fn():
...
exp = dist_base.Experiment(
estimator = estimator,
train_input_fn = input_fn,
eval_input_fn = input_fn,
eval_steps = 1)
exp.run()
__init__
__init__(
estimator,
train_input_fn,
eval_input_fn,
eval_metrics=None,
eval_steps=100,
train_monitors=None,
eval_hooks=None,
eval_delay_secs=60,
model_export_spec=None)
创建一个 tf.contrib.learn 模型训练任务分布式运行的 Experiment 对象。
参数说明:
estimator
:实现了 tf.contrib.learn.Estimator 接口的对象。train_input_fn
:函数对象,用于estimator.fit()
执行模型训练时返回 features 和 labels。evalu_input_fn
:函数对象,用于estimator.evaluate()
执行模型评估时返回 features 和 labels。eval_metrics
:字符串到 metrics 函数的字典,在执行esimator.evaluate()
进行模型评估时传递给 metrics 参数。eval_steps
:默认情况下,estimator.evalute()
在执行模型评估时会持续执行到输入数据耗尽。可以通过eval_steps
参数来指定estimator.evalute()
执行多少轮。train_monitors
:传递给estimator.fit()
方法的 monitor 列表。eval_hooks
:传递给estimator.evaluate()
方法的SessionRunHook
列表。eval_delay_secs
:在模型训练结束之后会执行一次模型评估,该参数用于指定在执行模型评估之间需要等待多长时间,单位为秒。model_export_spec
:ModelExportSpec
对象,用于模型导出,具体说明请参考 ModelExportSpec。
run
run()
执行 tf.contrib.learn 模型训练任务。