模型评估
本节是在手写数字识别 MNIST 样例基础上进行实现的。
在机器学习的模型训练中,通常将数据集分为训练集、验证集和测试集。使用训练集来训练模型,在训练的过程中,通过验证集来验证模型的性能以防止模型过拟合,最后训练完成后通过测试集来评估模型的性能。
TaaS 平台支持模型评估的结果反馈机制。通过反馈机制,TaaS 平台能够收集到所有训练任务得到的模型的评估结果。然后通过基于项目维度的图表可视化训练模型的性能发展趋势,以此来评估项目的进展状况。
我们可以在 model_fn
方法返回的 ModelFnHandler
对象中指定 model_metric_ops
参数来提供模型评估的指标和计算方法。
注:TaaS 平台的模型评测目前只在模型训练完成之后执行一次。
我们修改手写数字识别 MNIST 样例 中提供的 model_fn
方法。
def model_fn(sync, num_replicas):
# 模型定义操作
...
# 定义模型评测(准确率)的计算方法
def accuracy_evalute_fn(session):
return session.run(_accuracy,
feed_dict={
input_images: _mnist.validation.images,
labels: _mnist.validation.labels})
# 模型评估的度量指标和指标计算方法的字典。
# 度量指标(例如这里的"accuracy")就是最终在 TaaS 平台展示的度量指标名称。
# 而 accuracy_evalute_fn 函数对象就是计算 "accuracy" 这个指标的方法。
# 该函数需要返回一个 float32 的值作为度量指标的结果。
model_metric_ops = {
"accuracy": accuracy_evalute_fn
}
# 需要通过 model_fn 函数返回值 ModelFnHandler 对象的 model_metric_ops 参数
# 告知 TaaS 平台模型评估的配置。
return dist_base.ModelFnHandler(
global_step=_global_step,
optimizer=optimizer,
model_metric_ops = model_metric_ops)
完整代码请参考 caicloud/tensorflow-tutorial。