Experiement 类

支持的 caicloud.tensorflow 最低版本:2.1.0

所在模块:caicloud.clever.tensorflow.dist_base

保存 tf.contrib.learn 的模型训练的所有信息。

要支持 tf.contrib.learn 的模型训练任务在 TaaS 上分布式执行,需要创建一个 Experiment 对对象,然后调用其 run() 方法。基本代码框架如下:

import tensorflow as tf
import tensorflow.contrib.learn as learn

from caicloud.clever.tensorflow import dist_base

estimator = learn.Estimator(model_fn=...)

# input_fn for estimator.fit() or estimator.evaluate()
def input_fn():
    ...

exp = dist_base.Experiment(
    estimator = estimator,
    train_input_fn = input_fn,
    eval_input_fn = input_fn,
    eval_steps = 1)
exp.run()

__init__

__init__(
    estimator,
    train_input_fn,
    eval_input_fn,
    eval_metrics=None,
    eval_steps=100,
    train_monitors=None,
    eval_hooks=None,
    eval_delay_secs=60,
    model_export_spec=None)

创建一个 tf.contrib.learn 模型训练任务分布式运行的 Experiment 对象。

参数说明:

  • estimator:实现了 tf.contrib.learn.Estimator 接口的对象。

  • train_input_fn:函数对象,用于 estimator.fit() 执行模型训练时返回 features 和 labels。

  • evalu_input_fn:函数对象,用于 estimator.evaluate() 执行模型评估时返回 features 和 labels。

  • eval_metrics:字符串到 metrics 函数的字典,在执行 esimator.evaluate() 进行模型评估时传递给 metrics 参数。

  • eval_steps:默认情况下,estimator.evalute() 在执行模型评估时会持续执行到输入数据耗尽。可以通过 eval_steps 参数来指定 estimator.evalute() 执行多少轮。

  • train_monitors:传递给 estimator.fit() 方法的 monitor 列表。

  • eval_hooks:传递给 estimator.evaluate() 方法的 SessionRunHook 列表。

  • eval_delay_secs:在模型训练结束之后会执行一次模型评估,该参数用于指定在执行模型评估之间需要等待多长时间,单位为秒。

  • model_export_specModelExportSpec 对象,用于模型导出,具体说明请参考 ModelExportSpec

run

run()

执行 tf.contrib.learn 模型训练任务。

results matching ""

    No results matching ""