class DistTensorflowRunner

所在模块 caicloud.clever.tensorflow.dist_base

模型训练的分布式运行器。

我们在模型任务中需要定义一个 DistTensorflowRunner 对象,指定模型定义函数、自定义初始化函数等参数。然后调用该对象的 run 方法来执行模型训练。

__init__

__init__(model_fn, gen_init_fn=None, after_train_hook=None)

创建一个分布式 TensorFlow 运行器实例。

参数说明:

  • model_fn:我们实际业务模型定义的函数 ,不能为 None

    该函数用于构造用户业务 TensorFlow 模型的 Graph,接收两个参数:

    • sync:当前分布式运行是否采用参数同步更新模式。
    • num_replicas:分布式运行的 worker 副本个数。

    该函数返回值是一个 ModelFnHandler 对象。

  • gen_init_fn:用于生成用户自定义初始化方法的函数,不接收任何参数,可选,默认为 None。

    gen_init_fn 返回值是一个函数对象,如果不需要执行自定义初始化,则返回 None。该函数对象接收一个 tf.train.Scafford 对象和一个 tf.Session 对象作为参数。

    关于自定义初始化有两个需要注意的地方:

    1. 当 TaaS 平台执行 gen_init_fn 返回的初始化函数时,TensorFlow graph 已经被 finalized,不再允许添加新的 Operation,所以需要在 gen_init_fn 函数体中而不是在返回的函数中添加相关的 Operation(例如定义 tf.train.Saver)。
    2. TaaS 平台会优先从运行时指定的日志路径中查找最新的 checkpoint 文件,如果没有找到,则调用 gen_init_fn 返回的函数对象来进行模型初始化;否则,直接使用该 checkpoint 来初始化模型参数,而不调用 gen_init_fn 返回的函数对象。
  • after_train_hook:模型成功训练后钩子,可选,默认为 None。

    该钩子函数会在模型成功训练之后执行,该函数需要接收一个 tf.Session 对象作为参数。

返回值:

  • DistTensorflowRunner 对象

run

run(train_fn)

启动分布式模型训练。

参数说明:

  • train_fn:用于执行每一轮训练的具体操作的函数。

    该函数需要接收两个参数:

    • sessiontf.Session 对象。
    • num_global_step:当前训练所处的训练轮次。

    该函数返回一个 bool 值用于表示是否要终止模型训练。返回 True 表示终止模型训练,返回 False 表示继续模型训练。

results matching ""

    No results matching ""