模型自定义初始化
本节是在手写数字识别 MNIST 样例基础上进行实现的。
当训练一个非常复杂的 TensorFlow 模型时,每次都从头开始训练太耗时,我们往往会从预训练(pre-trained)的模型文件(checkpoint 文件)中加载模型的参数进行初始化操作,例如我们要训练 Inception 模型,如果从头开始训练,将会非常耗时,而我们可以通过从 google 提供的在 ImageNet 数据集上训练得到的 checkpoint 文件中加载模型参数来初始化,然后在我们的数据集上再进行二次训练即可。
TaaS 平台提供了模型参数自定义初始化的机制。我们只需要在定义分布式运行器 DistTensorflowRunner
对象中提供 gen_init_fn
参数即可。
def gen_init_fn():
"""获取自定义初始化函数。
我们在自定义模型初始化方法 `gen_init_fn` 通过一个 `tf.Saver` 来从指定的目录中加载
checkpoint 文件来初始化模型参数。
"""
checkpoint_path = FLAGS.checkpoint_dir
if checkpoint_path is None or checkpoint_path == "":
return None
if not tf.gfile.Exists(checkpoint_path):
print('WARNING: checkpoint path {0} not exists.'.format(checkpoint_path))
return None
# 获取指定目录下最新的一个 checkpoint 文件路径。
if tf.gfile.IsDirectory(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
else:
checkpoint_path = checkpoint_path
print('warm-start from checkpoint {0}'.format(checkpoint_path))
# 生成一个 tf.train.Saver 对象用于模型参数导入。
# tf.trainable_variables 函数可以获取到当前默认模型计算图中的所有要训练的参数列表,
# 如果需要加载命名不一样的参数的话,则需要自已按照 tf.train.Saver 的文档来进行处理。
# 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,而当 Base 框架调用自定义初始
# 化函数 init_from_checkpoint 的时候,TensorFlow 模型的 Graph 结构已经变成 finalized,
# 不再允许修改 Graph 结构。所以,这个定义必须放在 init_from_checkpoint 函数外面。
saver = tf.train.Saver(tf.trainable_variables())
def init_from_checkpoint(scaffold, sess):
"""执行自定义初始化的函数。
TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
本函数必须接收两个参数:
- scafford: tf.train.Scaffold 对象;
- sess: tf.Session 对象。
"""
saver.restore(sess, checkpoint_path)
return init_from_checkpoint
if __name__ == '__main__':
# 在定义 DistTensorflowRunner 对象的时候,通过 gen_init_fn 参数指定自定义模型初始化方法。
distTfRunner = dist_base.DistTensorflowRunner(
model_fn = model_fn,
after_train_hook = after_train_hook,
gen_init_fn = gen_init_fn)
distTfRunner.run(train_fn)
完整代码请参考 caicloud/tensorflow-tutorial。