class DistTensorflowRunner
所在模块 caicloud.clever.tensorflow.dist_base
。
模型训练的分布式运行器。
我们在模型任务中需要定义一个 DistTensorflowRunner
对象,指定模型定义函数、自定义初始化函数等参数。然后调用该对象的 run
方法来执行模型训练。
__init__
__init__(model_fn, gen_init_fn=None, after_train_hook=None)
创建一个分布式 TensorFlow 运行器实例。
参数说明:
model_fn
:我们实际业务模型定义的函数 ,不能为None
。该函数用于构造用户业务 TensorFlow 模型的 Graph,接收两个参数:
sync
:当前分布式运行是否采用参数同步更新模式。num_replicas
:分布式运行的 worker 副本个数。
该函数返回值是一个
ModelFnHandler
对象。gen_init_fn
:用于生成用户自定义初始化方法的函数,不接收任何参数,可选,默认为 None。gen_init_fn
返回值是一个函数对象,如果不需要执行自定义初始化,则返回 None。该函数对象接收一个tf.train.Scafford
对象和一个tf.Session
对象作为参数。关于自定义初始化有两个需要注意的地方:
- 当 TaaS 平台执行
gen_init_fn
返回的初始化函数时,TensorFlow graph 已经被 finalized,不再允许添加新的 Operation,所以需要在gen_init_fn
函数体中而不是在返回的函数中添加相关的 Operation(例如定义tf.train.Saver
)。 - TaaS 平台会优先从运行时指定的日志路径中查找最新的 checkpoint 文件,如果没有找到,则调用
gen_init_fn
返回的函数对象来进行模型初始化;否则,直接使用该 checkpoint 来初始化模型参数,而不调用gen_init_fn
返回的函数对象。
- 当 TaaS 平台执行
after_train_hook
:模型成功训练后钩子,可选,默认为 None。该钩子函数会在模型成功训练之后执行,该函数需要接收一个
tf.Session
对象作为参数。
返回值:
DistTensorflowRunner
对象
run
run(train_fn)
启动分布式模型训练。
参数说明:
train_fn
:用于执行每一轮训练的具体操作的函数。该函数需要接收两个参数:
session
:tf.Session
对象。num_global_step
:当前训练所处的训练轮次。
该函数返回一个
bool
值用于表示是否要终止模型训练。返回True
表示终止模型训练,返回False
表示继续模型训练。