class ModelFnHandler

所在模块 caicloud.clever.tensorflow.dist_base

ModelFnHandler 用于指明用户业务模型的一些信息,例如 global_step、优化器 optimizer、模型评估度量操作 model_metrics_ops 和模型导出配置 ModelExportSpec

__init__

__init__(
    global_step=None,
    optimizer=None,
    model_metrics_ops=None,
    model_export_spec=None,
    summary_op=_USE_DEFAULT)

创建一个 ModelFnHandler 对象。

参数说明:

  • global_stepNone 或者 tf.Tensor 对象。

    如果为 None,TaaS 平台运行框架会通过 tf.train.get_global_step() 方法自动获取,此时则需要用户在其提供的 model_fn 函数中将 global_step 变量命名为 'global_step' 或者将该变量添加到 Graph 的 GLOBAL_STEP 集合中。

  • optimizerNone 或者 tf.train.SyncReplicasOptimizer 对象。

    当采用参数同步更新模式时,必须通过 optimizer 参数来反馈 tf.train.SyncReplicasOptimizer 对象,TaaS 平台运行框架需要通过该对象获取相对应的初始化操作。

  • model_metrics_ops:模型评估度量的指标到具体计算方法的字典。例如 { "accuracy": compute_accuracy_fn}。其中 compute_accuracy_fn 是一个函数对象,该函数需要接收一个 tf.Session 对象作为参数,然后返回最终计算得的 accuracy 指标值。

    如果我们没有指定该参数,TaaS 平台运行框架将会通过 LOESES 集合获取计算模型损失 loss 值。

  • model_export_spec:模型导出配置 ModelExportSpec 对象。该参数为 None 时表示不导出模型。

  • summary_op:计算并收集模型 Graph 中 Summary 信息的 Operation。默认情况下,将执行 tf.summary.merge_all() 来收集模型 Graph 中的所有 Summary 信息。将该参数设置为 None,将关闭自动计算 Summary 信息的功能。

results matching ""

    No results matching ""