原生 TensorFlow API 模型训练任务模板
我们在 github 上提供了支持原生 TensorFlow API 模型训练任务的代码模板 template-raw.py。
为了方便,这边也提供一份拷贝:
# coding=utf-8
from __future__ import print_function
import tensorflow as tf
# 导入 CaiCloud TaaS 平台任务框架的模块
from caicloud.clever.tensorflow import dist_base
from caicloud.clever.tensorflow import model_exporter
tf.app.flags.DEFINE_string("export_dir",
"/tmp/mnist/saved_model",
"model export directory path.")
tf.app.flags.DEFINE_string("checkpoint_dir",
"",
"model checkpoint directory path.")
FLAGS = tf.app.flags.FLAGS
_train_op = None
def model_fn(sync, num_replicas):
"""TensorFlow 模型定义函数。
在任务执行的时候调用该函数用于生成 TensorFlow 模型的计算图(tf.Graph)。
在函数中定义模型的前向推理算法、损失函数、优化器以及模型评估的指标和计算方法等信息。
参数:
- `sync`:当前是否采用参数同步更新模式。
- `num_replicas`:分布式 TensorFlow 的计算节点(worker)个数。
"""
global _train_op
# TODO:添加业务模型定义操作。
# global_step = ...
# _train_op = ...
# 添加模型评估配置:
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# def accuracy_evalute_fn(session):
# return session.run(accuracy, ...)
# model_metric_ops = {
# "accuracy": accuracy_evalute_fn
# }
# 定义模型导出配置
# model_export_spec = model_exporter.ModelExportSpec(
# export_dir=FLAGS.export_dir,
# input_tensors={"image": _input_images},
# output_tensors={"logits": logits})
# model_fn 函数需要返回 ModelFnHandler 对象告知 TaaS 平台所构建的模型的一些信息,
# 例如 global_step、优化器 Optimizer、模型评估指标以及模型导出的相关配置等等。
# 详细信息请参考 docs.caicloud.io。
return dist_base.ModelFnHandler(
global_step = global_step,
model_metric_ops = model_metric_ops,
model_export_spec = model_export_spec)
def train_fn(session, num_global_step):
"""模型训练的每一轮操作。
模型训练训练中的每一轮训练时的操作。
参数:
- `session`:tf.Session 对象;
- `num_global_step`:当前所处训练轮次。
"""
# TODO:添加业务模型训练操作。
# train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。
# 返回 True,表示终止训练;否则,TaaS 将继续下一轮训练。
# 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。当模型效果
# 达到预期效果,便可以通过返回 True 来结束模型训练。
return False
def gen_init_fn():
"""获取自定义初始化函数。
有些情况下,我需要从某个事先训练好的 checkpoint 文件中加载模型的参数。此时,我们需要自
己实现使用 tf.Saver() 从该 checkpoint 中加载模型参数进行自定义初始化的函数。
注:如果不需要自定义初始化,可以不提供 gen_init_fn 实现,或者 gen_init_fn 返回 None。
"""
# TODO: 添加自己的处理逻辑
# 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,
# 而当 Base 框架调用自定义初始化函数 init_from_checkpoint 的时候,
# TensorFlow 模型的 Graph 结构已经变成 finalized,不再允许修改 Graph 结构。
# 所以,这个定义必须放在 init_from_checkpoint 函数外面。
saver = tf.train.Saver(tf.trainable_variables())
def init_from_checkpoint(scaffold, sess):
"""执行自定义初始化的函数。
TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
本函数必须接收两个参数:
- scafford: tf.train.Scaffold 对象;
- sess: tf.Session 对象。
"""
saver.restore(sess, checkpoint_path)
return init_from_checkpoint
def after_train_hook(session):
"""模型训练操作。
TaaS 在整个模型训练结束之后会调用该函数来进行相关的善后处理。
这些善后处理需要您基于业务需要来提供,例如模型测试等。
参数:
- `session`:tf.Session 对象。
"""
pass
if __name__ == '__main__':
# 定义分布式 TensorFlow 运行器 DistTensorflowRunner 对象。
distTfRunner = dist_base.DistTensorflowRunner(
model_fn = model_fn,
after_train_hook = after_train_hook,
gen_init_fn = gen_init_fn)
# 调用 DistTensorflowRunner 对象的 run 方法执行分布式模型训练,需要传递每轮模型训练的
# 操作实现函数 train_fn。
distTfRunner.run(train_fn)