class ModelFnHandler
所在模块 caicloud.clever.tensorflow.dist_base。
ModelFnHandler 用于指明用户业务模型的一些信息,例如 global_step、优化器 optimizer、模型评估度量操作 model_metrics_ops 和模型导出配置 ModelExportSpec。
__init__
__init__(
global_step=None,
optimizer=None,
model_metrics_ops=None,
model_export_spec=None,
summary_op=_USE_DEFAULT)
创建一个 ModelFnHandler 对象。
参数说明:
global_step:None或者tf.Tensor对象。如果为
None,TaaS 平台运行框架会通过tf.train.get_global_step()方法自动获取,此时则需要用户在其提供的model_fn函数中将global_step变量命名为 'global_step' 或者将该变量添加到 Graph 的 GLOBAL_STEP 集合中。optimizer:None或者tf.train.SyncReplicasOptimizer对象。当采用参数同步更新模式时,必须通过
optimizer参数来反馈tf.train.SyncReplicasOptimizer对象,TaaS 平台运行框架需要通过该对象获取相对应的初始化操作。model_metrics_ops:模型评估度量的指标到具体计算方法的字典。例如{ "accuracy": compute_accuracy_fn}。其中compute_accuracy_fn是一个函数对象,该函数需要接收一个tf.Session对象作为参数,然后返回最终计算得的 accuracy 指标值。如果我们没有指定该参数,TaaS 平台运行框架将会通过 LOESES 集合获取计算模型损失 loss 值。
model_export_spec:模型导出配置ModelExportSpec对象。该参数为None时表示不导出模型。summary_op:计算并收集模型 Graph 中 Summary 信息的 Operation。默认情况下,将执行tf.summary.merge_all()来收集模型 Graph 中的所有 Summary 信息。将该参数设置为 None,将关闭自动计算 Summary 信息的功能。