模型自定义初始化

本节是在手写数字识别 MNIST 样例基础上进行实现的。

当训练一个非常复杂的 TensorFlow 模型时,每次都从头开始训练太耗时,我们往往会从预训练(pre-trained)的模型文件(checkpoint 文件)中加载模型的参数进行初始化操作,例如我们要训练 Inception 模型,如果从头开始训练,将会非常耗时,而我们可以通过从 google 提供的在 ImageNet 数据集上训练得到的 checkpoint 文件中加载模型参数来初始化,然后在我们的数据集上再进行二次训练即可。

TaaS 平台提供了模型参数自定义初始化的机制。我们只需要在定义分布式运行器 DistTensorflowRunner 对象中提供 gen_init_fn 参数即可。

def gen_init_fn():
    """获取自定义初始化函数。

    我们在自定义模型初始化方法 `gen_init_fn` 通过一个 `tf.Saver` 来从指定的目录中加载 
    checkpoint 文件来初始化模型参数。 
    """
    checkpoint_path = FLAGS.checkpoint_dir
    if checkpoint_path is None or checkpoint_path == "":
        return None

    if not tf.gfile.Exists(checkpoint_path):
        print('WARNING: checkpoint path {0} not exists.'.format(checkpoint_path))
        return None

    # 获取指定目录下最新的一个 checkpoint 文件路径。
    if tf.gfile.IsDirectory(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
    else:
        checkpoint_path = checkpoint_path
    print('warm-start from checkpoint {0}'.format(checkpoint_path))

    # 生成一个 tf.train.Saver 对象用于模型参数导入。
    # tf.trainable_variables 函数可以获取到当前默认模型计算图中的所有要训练的参数列表,
    # 如果需要加载命名不一样的参数的话,则需要自已按照 tf.train.Saver 的文档来进行处理。
    # 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,而当 Base 框架调用自定义初始
    # 化函数 init_from_checkpoint 的时候,TensorFlow 模型的 Graph 结构已经变成 finalized,
    # 不再允许修改 Graph 结构。所以,这个定义必须放在  init_from_checkpoint 函数外面。
    saver = tf.train.Saver(tf.trainable_variables())

    def init_from_checkpoint(scaffold, sess):
        """执行自定义初始化的函数。

        TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
        如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
        本函数必须接收两个参数:
          - scafford: tf.train.Scaffold 对象;
          - sess: tf.Session 对象。
        """
        saver.restore(sess, checkpoint_path)
    return init_from_checkpoint

if __name__ == '__main__':
    # 在定义 DistTensorflowRunner 对象的时候,通过 gen_init_fn 参数指定自定义模型初始化方法。 
    distTfRunner = dist_base.DistTensorflowRunner(
        model_fn = model_fn,
        after_train_hook = after_train_hook,
        gen_init_fn = gen_init_fn)
    distTfRunner.run(train_fn)

完整代码请参考 caicloud/tensorflow-tutorial

results matching ""

    No results matching ""