模型参数同步更新

关于同步和异步更新

分布式 TensorFlow 集群由参数服务器(Parameter Server,简称 PS)和计算服务器(简称 worker)组成。

PNG

PS 主要用于分片保存模型参数,而 worker 在每轮训练开始之前会从 PS 中获取当前模型参数,然后进行训练,训练结束后再去更新模型的参数。于是不同 worker 之间更新 PS 中的模型参数的交互方式就涉及到同步更新或者异步更新的问题。下图中上下半部分分别展示了同步更新和异步更新的处理过程。

PNG

异步更新:在异步更新模式下,worker 之间没有相互依赖。每个 worker 在每轮训练开始前从 PS 获取模型参数,读取训练数据,进行训练,训练结束后便立即应用梯度来更新 PS 上的模型参数。异步更新模式下,每个 worker 训练结束就更新模型参数,然后继续下一轮更新,所以会训练得比较快,但收敛相对较慢

同步更新:在同步更新模式下,TensorFlow 在每轮训练的时候需要汇总所有 worker 训练得到的梯度值,然后取平均值来更新 PS 上的模型参数。 同步更新的优点是收敛快,但worker 节点性能可能存在差异,会导致出现等待的情况出现,训练相对较慢

关于 tf.train.SyncReplicasOptimizer

TensorFlow API 提供的 Optimizer 都是采用异步更新模式将分布式集群的 worker 节点上计算得到的梯度值应用到 PS 节点。另外,TensorFlow 提供了 tf.train.SyncReplicasOptimizer 来专门处理模型参数的同步更新。

tf.train.SyncReplicasOptimizer 会为每个要训练的模型参数提供一个汇总梯度的 accumulator,在每轮训练的时候收集来自所有 worker 的梯度,然后取平均值来更新相对应的模型参数。

MNIST 样例代码

我们将基本模型训练中的 model_fn 函数进行了下面调整,定义了一个 tf.train.SyncReplicasOptimizer 对象来封装原始的 optimizer 对象来提供同步更新模式。具体实现代码如下所示(其他代码没有变化,此处就省略了):

注:同步更新需要对 tf.train.SyncReplicasOptimizer 对象进行特殊的初始化,所以需要在 model_fn 返回的 ModelFnHandler 对象中提供 tf.train.SyncReplicasOptimizer 对象。

# coding=utf-8
import tensorflow as tf

# ...

def model_fn(sync, num_replicas):
    # ...

    # 使用 AdagradOptimizer 来优化模型。
    # 如果分布式执行采用参数同步更新模式,则需要定义 tf.train.SyncReplicasOptimizer 
    # 对象,并且在 model_fn 函数返回的 ModelFnHandler 中需要提供该  tf.train.SyncReplicasOptimizer 
    # 对象,因为 TaaS 平台需要通过 tf.train.SyncReplicasOptimizer 对象获取同步更新的一
    # 些初始化操作。如果采用同步更新模式时,没有提供 tf.train.SyncReplicasOptimizer 
    # 对象,TaaS 平台会抛出运行时错误。 
    optimizer = tf.train.AdagradOptimizer(0.01);
    if sync:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=num_replicas,
            total_num_replicas=num_replicas,
            name="mnist_sync_replicas")

    _train_op = optimizer.minimize(cross_entropy, global_step=_global_step)

    # ...

    return dist_base.ModelFnHandler(
        global_step=_global_step,
        optimizer=optimizer)

完整代码可以参考 caicloud/tensorflow-tutorial

results matching ""

    No results matching ""