模型评估

本节是在手写数字识别 MNIST 样例基础上进行实现的。

在机器学习的模型训练中,通常将数据集分为训练集、验证集和测试集。使用训练集来训练模型,在训练的过程中,通过验证集来验证模型的性能以防止模型过拟合,最后训练完成后通过测试集来评估模型的性能。

TaaS 平台支持模型评估的结果反馈机制。通过反馈机制,TaaS 平台能够收集到所有训练任务得到的模型的评估结果。然后通过基于项目维度的图表可视化训练模型的性能发展趋势,以此来评估项目的进展状况。

我们可以在 model_fn 方法返回的 ModelFnHandler 对象中指定 model_metric_ops 参数来提供模型评估的指标和计算方法。

注:TaaS 平台的模型评测目前只在模型训练完成之后执行一次。

我们修改手写数字识别 MNIST 样例 中提供的 model_fn 方法。

def model_fn(sync, num_replicas):
    # 模型定义操作
    ...

    # 定义模型评测(准确率)的计算方法
    def accuracy_evalute_fn(session):
        return session.run(_accuracy,
                           feed_dict={
                               input_images: _mnist.validation.images,
                               labels: _mnist.validation.labels})

    # 模型评估的度量指标和指标计算方法的字典。
    # 度量指标(例如这里的"accuracy")就是最终在 TaaS 平台展示的度量指标名称。 
    # 而 accuracy_evalute_fn 函数对象就是计算 "accuracy" 这个指标的方法。
    # 该函数需要返回一个 float32 的值作为度量指标的结果。
    model_metric_ops = {
        "accuracy": accuracy_evalute_fn
    }

    # 需要通过 model_fn 函数返回值 ModelFnHandler 对象的 model_metric_ops 参数
    # 告知 TaaS 平台模型评估的配置。
    return dist_base.ModelFnHandler(
        global_step=_global_step,
        optimizer=optimizer,
        model_metric_ops = model_metric_ops)

完整代码请参考 caicloud/tensorflow-tutorial

results matching ""

    No results matching ""