MNIST 基本模型训练任务

本节将通过手写数字识别 MNIST 的样例代码来手把手教你如何开发在 TaaS 上执行的分布式 TensorFlow 模型训练代码。

本节我们将定义三个主要函数:

  • model_fn:TensorFlow 模型的定义函数,TaaS 在运行的时候会调用该函数来生成模型的 Graph。
  • train_fn:执行具体每一轮模型训练的操作。
  • after_train_hook:模型训练完成后钩子,TaaS 会在该模型训练任务结束之后调用该函数。

下面我们提供一个能够在 TaaS 平台上分布式执行的手写识别体 MNIST 样例代码:

# coding=utf-8
import tensorflow as tf

# 导入 TaaS 平台的 caicloud.clever.tensorflow.dist_base 模块
from caicloud.clever.tensorflow import dist_base

# 使用 TensorFlow 提供的 API 来下载并导入 MNIST 数据。
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
_mnist = read_data_sets('/tmp/mnist-data', one_hot=True)

_input_images = None
_labels = None
_loss = None
_train_op = None
_global_step = None
_accuracy = None

def model_fn(sync, num_replicas):
    """模型定义函数 model_fn 定义 TensorFlow 计算图(tf.Graph)。

    model_fn 函数会接收到两个参数:
    - sync:bool 值,表示当前分布式运行是否采用参数同步更新,为 True 表示同步更新,否则表示异步更新。
    - num_replicas:int 值,用于表示当前分布式运行的 worker 个数。

    TaaS 在执行该任务的时候通过调用该函数来生成模型计算图。定义模型计算图中的计算节点必须要在 model_fn 中
    提供,当分布式任务开始执行之后,模型计算图就会被 finalized,不再允许往计算图中添加新的计算节点。
    """

    # 这些变量在后续的训练操作函数 train_fn 中会使用到,所以这里使用了 global 变量。
    global _input_images, _loss, _labels, _train_op, _accuracy
    global _mnist, _global_step

    # 构建模型的前向传播算法。
    # 本示例值只提供了两层的全连接神经网络。
    _input_images = tf.placeholder(tf.float32, [None, 784], name='image')
    W = tf.Variable(tf.zeros([784, 10]), name='weights')
    tf.summary.histogram("weights", W)
    b = tf.Variable(tf.zeros([10]), name='bias')
    tf.summary.histogram("bias", b)
    logits = tf.matmul(_input_images, W) + b

    # 记录训练轮数的变量 global_step 在定义时需要命名为 'global_step' 或者
    # 使用 tf.add_to_collection 函数将该变量加入计算图的 GLOBAL_STEPS 集合中。
    # 因为 TaaS 平台需要根据该 global_step 变量来控制模型训练的循环和结束,如果
    # TaaS 平台在运行的时候获取不到 global_step 变量,将会抛出运行时错误。
    _global_step = tf.Variable(0, name='global_step', trainable=False)

    # 定义了交叉熵损失
    _labels = tf.placeholder(tf.float32, [None, 10], name='labels')
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=_labels))
    _loss = tf.reduce_mean(cross_entropy, name='loss')
    tf.add_to_collection(tf.GraphKeys.LOSSES, _loss)

    # 使用 AdagradOptimizer 来优化模型。
    optimizer = tf.train.AdagradOptimizer(0.01);

    _train_op = optimizer.minimize(cross_entropy, global_step=_global_step)

    # 定义评估模型的正确率
    correct_prediction = tf.equal(tf.argmax(logits, 1),
                                  tf.argmax(_labels, 1))
    _accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # model_fn 函数需要返回一个 dist_base.ModelFnHandler 对象,通过该对象告知 TaaS 平台下面信息:
    # - global_step:None 或者模型训练的全局轮数变量 global_step。不过,当 `global_step` 为 `None` 
    #   或者没有传递 `global_step` 参数的时候,要求提供的 `model_fn` 函数中定义的 `global_step` 变量
    #   必须命名为 'global_step' 或者将该变量添加到 Graph 的 GLOBAL_STEP 集合中。
    # - optimizer:None 或者 tf.train.SyncReplicasOptimizer 对象。
    #   如果模型只是使用了异步的参数更新方式,则 `Optimizer` 可以设置为 None 或者不提供。而如果 
    #   `model_fn` 中支持了参数同步更新模式,则 `optimizer` 参数必须设置成 
    #   `tf.train.SyncReplicasOptimizer` 对象,否则在运行时开启了同步更新模式会抛出运行错误。    
    return dist_base.ModelFnHandler(
        global_step=_global_step,
        optimizer=optimizer)

# 定义一个 local_step 记录在当前 worker 上训练的轮数。
_local_step = 0

def train_fn(session, num_global_step):
    """模型训练函数 `train_fn` 提供了模型每一轮训练的具体操作。

    在每轮训练的时候,我们从 MNIST 的训练数据集中读取长度为 100 的批量数据,然后 feed 到模型的 `_input_images` 和 `_labels` 中,然后执行 `_train_op` 操作。
    该方法接收两个参数:
    - `session`:`tf.Session` 对象;
    - `num_global_step`:训练当前所处的 `global_step` 轮数。    

    注:在函数 train_fn() 中只执行模型训练,此时模型的 Graph 已经 finalized,不再支持定义新的 Operation。
    """

    # 使用 global 变量来访问 model_fn 函数中定义的模型 Tensor。
    global _local_step, _train_op, _loss, _input_images, _labels, _mnist, _accuracy

    start_time = time.time()
    _local_step += 1
    batch_xs, batch_ys = _mnist.train.next_batch(100)
    feed_dict = {_input_images: batch_xs,
                 _labels: batch_ys}
    _, loss_value, np_global_step = session.run(
        [_train_op, _loss, _global_step],
        feed_dict=feed_dict)
    duration = time.time() - start_time

    # 每隔 50 轮打印一次训练模型的损失
    if _local_step % 50 == 0:
        print('Step {0}: loss = {1:0.2f} ({2:0.3f} sec), global step: {3}.'.format(
            _local_step, loss_value, duration, np_global_step))

    # 每隔 1000 轮,使用 MNIST 验证数据来计算并打印训练模型的正确率。
    if _local_step % 1000 == 0:
        print("Accuracy for validation data: {0:0.3f}".format(
            session.run(
                _accuracy,
                feed_dict={
                    _input_images: _mnist.validation.images,
                    _labels: _mnist.validation.labels})))

    # train_fn 函数返回一个 bool 值,用于表达是否要终止模型训练。返回 True 表示要终止模型训练。
    # 例如,为了防止训练模型过拟合,在训练过程中需要时不时在验证数据集上评估模型的性能,
    # 当模型性能达到自己想要的预期,便可以返回 True 来中断模型训练。
    return False

def after_train_hook(session):
    """模型训练后勾子函数。

    该函数会在模型训练结束之后被调用来执行一些善后处理。
    该函数接收一个参数:
      - `session`:`tf.Session` 对象。 
    模型训练后钩子 `after_train_hook` 中使用 MNIST 测试数据集来计算训练完成的模型的正确率。
    """
    print("Train done.")

    # 此处使用 MNIST 的测试数据集来计算训练模型的正确率。
    print("Accuracy for test data: {0:0.3f}".format(
        session.run(
            _accuracy,
            feed_dict={
                _input_images: _mnist.test.images,
                _labels: _mnist.test.labels})))

if __name__ == '__main__':
    # 生成分布式运行器。
    # TaaS 平台提供了一个分布式运行器 DistTensorflowRunner 类。初始化一个分布式运行器的时候,
    # 需要提供上面定义的三个函数:model_fn、after_train_hook 和 train_fn。
    distTfRunner = dist_base.DistTensorflowRunner(
        model_fn = model_fn,
        after_train_hook = after_train_hook)
    distTfRunner.run(train_fn)

完整代码可以参考 caicloud/tensorflow-tutorial

本地测试

Caicloud TaaS 提供了 caicloud.tensorflow 用于本地进行开发调试。通过 pip 可以安装。

$ sudo pip install caicloud.tensorflow

安装成功后,我们在本地便可以直接通过运行自己实现的业务代码文件来测试。详细情况请参考环境准备中的本地开发测试包

我们将上面实现的手写数字识别 MNIST 任务代码保存到文件 mnist-demo.py 中,然后本地运行。

$ TF_MAX_STEPS=2000 TF_LOGDIR=/tmp/mnist-log python mnist-demo.py

本地测试完成后,我们便可以将该代码上传到 TaaS 平台的数据存储中,然后发起分布式的模型训练任务了。

results matching ""

    No results matching ""