MNIST 基本模型训练任务
本节将通过手写数字识别 MNIST 的样例代码来手把手教你如何开发在 TaaS 上执行的分布式 TensorFlow 模型训练代码。
本节我们将定义三个主要函数:
model_fn
:TensorFlow 模型的定义函数,TaaS 在运行的时候会调用该函数来生成模型的 Graph。train_fn
:执行具体每一轮模型训练的操作。after_train_hook
:模型训练完成后钩子,TaaS 会在该模型训练任务结束之后调用该函数。
下面我们提供一个能够在 TaaS 平台上分布式执行的手写识别体 MNIST 样例代码:
# coding=utf-8
import tensorflow as tf
# 导入 TaaS 平台的 caicloud.clever.tensorflow.dist_base 模块
from caicloud.clever.tensorflow import dist_base
# 使用 TensorFlow 提供的 API 来下载并导入 MNIST 数据。
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
_mnist = read_data_sets('/tmp/mnist-data', one_hot=True)
_input_images = None
_labels = None
_loss = None
_train_op = None
_global_step = None
_accuracy = None
def model_fn(sync, num_replicas):
"""模型定义函数 model_fn 定义 TensorFlow 计算图(tf.Graph)。
model_fn 函数会接收到两个参数:
- sync:bool 值,表示当前分布式运行是否采用参数同步更新,为 True 表示同步更新,否则表示异步更新。
- num_replicas:int 值,用于表示当前分布式运行的 worker 个数。
TaaS 在执行该任务的时候通过调用该函数来生成模型计算图。定义模型计算图中的计算节点必须要在 model_fn 中
提供,当分布式任务开始执行之后,模型计算图就会被 finalized,不再允许往计算图中添加新的计算节点。
"""
# 这些变量在后续的训练操作函数 train_fn 中会使用到,所以这里使用了 global 变量。
global _input_images, _loss, _labels, _train_op, _accuracy
global _mnist, _global_step
# 构建模型的前向传播算法。
# 本示例值只提供了两层的全连接神经网络。
_input_images = tf.placeholder(tf.float32, [None, 784], name='image')
W = tf.Variable(tf.zeros([784, 10]), name='weights')
tf.summary.histogram("weights", W)
b = tf.Variable(tf.zeros([10]), name='bias')
tf.summary.histogram("bias", b)
logits = tf.matmul(_input_images, W) + b
# 记录训练轮数的变量 global_step 在定义时需要命名为 'global_step' 或者
# 使用 tf.add_to_collection 函数将该变量加入计算图的 GLOBAL_STEPS 集合中。
# 因为 TaaS 平台需要根据该 global_step 变量来控制模型训练的循环和结束,如果
# TaaS 平台在运行的时候获取不到 global_step 变量,将会抛出运行时错误。
_global_step = tf.Variable(0, name='global_step', trainable=False)
# 定义了交叉熵损失
_labels = tf.placeholder(tf.float32, [None, 10], name='labels')
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=_labels))
_loss = tf.reduce_mean(cross_entropy, name='loss')
tf.add_to_collection(tf.GraphKeys.LOSSES, _loss)
# 使用 AdagradOptimizer 来优化模型。
optimizer = tf.train.AdagradOptimizer(0.01);
_train_op = optimizer.minimize(cross_entropy, global_step=_global_step)
# 定义评估模型的正确率
correct_prediction = tf.equal(tf.argmax(logits, 1),
tf.argmax(_labels, 1))
_accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# model_fn 函数需要返回一个 dist_base.ModelFnHandler 对象,通过该对象告知 TaaS 平台下面信息:
# - global_step:None 或者模型训练的全局轮数变量 global_step。不过,当 `global_step` 为 `None`
# 或者没有传递 `global_step` 参数的时候,要求提供的 `model_fn` 函数中定义的 `global_step` 变量
# 必须命名为 'global_step' 或者将该变量添加到 Graph 的 GLOBAL_STEP 集合中。
# - optimizer:None 或者 tf.train.SyncReplicasOptimizer 对象。
# 如果模型只是使用了异步的参数更新方式,则 `Optimizer` 可以设置为 None 或者不提供。而如果
# `model_fn` 中支持了参数同步更新模式,则 `optimizer` 参数必须设置成
# `tf.train.SyncReplicasOptimizer` 对象,否则在运行时开启了同步更新模式会抛出运行错误。
return dist_base.ModelFnHandler(
global_step=_global_step,
optimizer=optimizer)
# 定义一个 local_step 记录在当前 worker 上训练的轮数。
_local_step = 0
def train_fn(session, num_global_step):
"""模型训练函数 `train_fn` 提供了模型每一轮训练的具体操作。
在每轮训练的时候,我们从 MNIST 的训练数据集中读取长度为 100 的批量数据,然后 feed 到模型的 `_input_images` 和 `_labels` 中,然后执行 `_train_op` 操作。
该方法接收两个参数:
- `session`:`tf.Session` 对象;
- `num_global_step`:训练当前所处的 `global_step` 轮数。
注:在函数 train_fn() 中只执行模型训练,此时模型的 Graph 已经 finalized,不再支持定义新的 Operation。
"""
# 使用 global 变量来访问 model_fn 函数中定义的模型 Tensor。
global _local_step, _train_op, _loss, _input_images, _labels, _mnist, _accuracy
start_time = time.time()
_local_step += 1
batch_xs, batch_ys = _mnist.train.next_batch(100)
feed_dict = {_input_images: batch_xs,
_labels: batch_ys}
_, loss_value, np_global_step = session.run(
[_train_op, _loss, _global_step],
feed_dict=feed_dict)
duration = time.time() - start_time
# 每隔 50 轮打印一次训练模型的损失
if _local_step % 50 == 0:
print('Step {0}: loss = {1:0.2f} ({2:0.3f} sec), global step: {3}.'.format(
_local_step, loss_value, duration, np_global_step))
# 每隔 1000 轮,使用 MNIST 验证数据来计算并打印训练模型的正确率。
if _local_step % 1000 == 0:
print("Accuracy for validation data: {0:0.3f}".format(
session.run(
_accuracy,
feed_dict={
_input_images: _mnist.validation.images,
_labels: _mnist.validation.labels})))
# train_fn 函数返回一个 bool 值,用于表达是否要终止模型训练。返回 True 表示要终止模型训练。
# 例如,为了防止训练模型过拟合,在训练过程中需要时不时在验证数据集上评估模型的性能,
# 当模型性能达到自己想要的预期,便可以返回 True 来中断模型训练。
return False
def after_train_hook(session):
"""模型训练后勾子函数。
该函数会在模型训练结束之后被调用来执行一些善后处理。
该函数接收一个参数:
- `session`:`tf.Session` 对象。
模型训练后钩子 `after_train_hook` 中使用 MNIST 测试数据集来计算训练完成的模型的正确率。
"""
print("Train done.")
# 此处使用 MNIST 的测试数据集来计算训练模型的正确率。
print("Accuracy for test data: {0:0.3f}".format(
session.run(
_accuracy,
feed_dict={
_input_images: _mnist.test.images,
_labels: _mnist.test.labels})))
if __name__ == '__main__':
# 生成分布式运行器。
# TaaS 平台提供了一个分布式运行器 DistTensorflowRunner 类。初始化一个分布式运行器的时候,
# 需要提供上面定义的三个函数:model_fn、after_train_hook 和 train_fn。
distTfRunner = dist_base.DistTensorflowRunner(
model_fn = model_fn,
after_train_hook = after_train_hook)
distTfRunner.run(train_fn)
完整代码可以参考 caicloud/tensorflow-tutorial。
本地测试
Caicloud TaaS 提供了 caicloud.tensorflow 用于本地进行开发调试。通过 pip 可以安装。
$ sudo pip install caicloud.tensorflow
安装成功后,我们在本地便可以直接通过运行自己实现的业务代码文件来测试。详细情况请参考环境准备中的本地开发测试包。
我们将上面实现的手写数字识别 MNIST 任务代码保存到文件 mnist-demo.py 中,然后本地运行。
$ TF_MAX_STEPS=2000 TF_LOGDIR=/tmp/mnist-log python mnist-demo.py
本地测试完成后,我们便可以将该代码上传到 TaaS 平台的数据存储中,然后发起分布式的模型训练任务了。