class ModelFnHandler
所在模块 caicloud.clever.tensorflow.dist_base
。
ModelFnHandler
用于指明用户业务模型的一些信息,例如 global_step
、优化器 optimizer
、模型评估度量操作 model_metrics_ops
和模型导出配置 ModelExportSpec
。
__init__
__init__(
global_step=None,
optimizer=None,
model_metrics_ops=None,
model_export_spec=None,
summary_op=_USE_DEFAULT)
创建一个 ModelFnHandler
对象。
参数说明:
global_step
:None
或者tf.Tensor
对象。如果为
None
,TaaS 平台运行框架会通过tf.train.get_global_step()
方法自动获取,此时则需要用户在其提供的model_fn
函数中将global_step
变量命名为 'global_step' 或者将该变量添加到 Graph 的 GLOBAL_STEP 集合中。optimizer
:None
或者tf.train.SyncReplicasOptimizer
对象。当采用参数同步更新模式时,必须通过
optimizer
参数来反馈tf.train.SyncReplicasOptimizer
对象,TaaS 平台运行框架需要通过该对象获取相对应的初始化操作。model_metrics_ops
:模型评估度量的指标到具体计算方法的字典。例如{ "accuracy": compute_accuracy_fn}
。其中compute_accuracy_fn
是一个函数对象,该函数需要接收一个tf.Session
对象作为参数,然后返回最终计算得的 accuracy 指标值。如果我们没有指定该参数,TaaS 平台运行框架将会通过 LOESES 集合获取计算模型损失 loss 值。
model_export_spec
:模型导出配置ModelExportSpec
对象。该参数为None
时表示不导出模型。summary_op
:计算并收集模型 Graph 中 Summary 信息的 Operation。默认情况下,将执行tf.summary.merge_all()
来收集模型 Graph 中的所有 Summary 信息。将该参数设置为 None,将关闭自动计算 Summary 信息的功能。