使用tf.estimator进行分布式培训,从而产生更多培训步骤

时间:2017-08-31 20:57:22

标签: python tensorflow google-cloud-ml-engine

我正在试验Cloud ML Engine上的分布式培训选项,并观察了一些奇特的结果。我基本上改变了人口普查自定义估算器示例,以包含一个略有不同的模型,并将我的损失函数更改为AdamOptimizer作为唯一真正的更改。基于这个thread,我的理解是任何分布式培训都应该是数据并行异步培训,这将建议"如果您在10个工作节点中分配10,000个批次,则每个节点大约可以处理1000个批次。 #34;在我的实验中,我有大约650k的训练样例,我正在运行以下实验,一个批次大小为128的1个纪元。鉴于650k训练样例和128个批量大小,我希望在一个大约5.1k步骤时代。以下是我看到的针对不同--scale-tier

的表现

未分发

  • 基本:8步/秒,5.1k步,11分钟挂壁时间
  • BASIC_GPU:24步/秒,5.1k步,3.5分钟待机时间

分布式

  • STANDARD_1:14.5步/秒 - 26k步(26k * 128 = ~3.3M,这比实际数据中的训练样本多),29分钟的停留时间

  • CUSTOM - 5个complex_model_m工作人员,2个large_model参数服务器:27步/秒,31k步(128 * 31k = ~3.9M,这比实际数据中的650k训练样本多),待机时间

我的期望是基于文章的数据平行是分布式培训会将批次分成所有工人,所以如果我有5个批次的5个工人,那么每个工人将执行~1,000批次。然而,我观察到的实际行为似乎更接近于自己执行1个时代的5名工人中的每一个。在分布式设置中进行训练时,在一个时代中采用的步数是训练样例的6倍 - 我知道步骤的真正定义是每次更新渐变时,但我对数据并行训练的理解是这只会拆分批次,所以应该有相同数量的梯度更新 - 是否有任何理由为什么会出现这种行为?在数据并行异步训练分布式环境中需要更多的训练步骤是否有意义?任何人都可以解释我观察到的行为吗?

2 个答案:

答案 0 :(得分:9)

之前的答案在解释性能瓶颈方面做得很好。让我解释一下“epochs”以及TensorFlow如何处理数据集。

分布式培训在TensorFlow中的工作方式是每个工作人员独立地遍历整个数据集。一种常见的误解是,训练集是在工人之间划分的,但事实并非如此。

在具有队列的典型设置中(参见this tutorial),每个工作人员创建自己的队列会发生什么。该队列充满了所有训练文件的所有文件名列表(通常列表被洗牌,每次队列耗尽,它都会被重新填充和重新洗牌)。每个文件都是逐个实例读取的,数据被解析,预处理,然后被送入另一个队列,在那里对实例进行洗牌和批处理。读取任何文件的最后一个实例后,将从文件名队列中弹出下一个文件名。如果没有更多要弹出的文件,则“epoch”已完成。

这里重要的一点是,所有这些队列都默认为 local - 不共享。因此,每个工作人员都独立地重复相同的工作 - 使用所有文件创建队列并迭代整个数据集。那么,一个完整的纪元大致等于完整数据集中的实例数*工人的数量。 (我不确定你的standard_1结果,但是CUSTOM的结果意味着你有你的主人+5名工人= 6名工人* 650K例子*(1批/ 128例)= 31K步骤。

仅供参考,不鼓励使用时代参数化分布式培训,因为它过于混乱,一般来说甚至可能存在问题。坚持使用max_steps。

请注意,作为TensorFlow设计的结果,“批量大小”表示每个工作人员的批量大小 。但是每个工作人员将以大致相同的速率向参数服务器发送更新,因此在大致相当于处理一个“批处理”的时间的时间段内,参数发生的更新次数大致为{{ 1}} * batch_size。这就是我们所说的有效批量大小。这反过来会产生一些实际后果:

  1. 您倾向于使用较小的batch_sizes,尤其是如果您拥有大量工作人员,以便有效批量大小保持理智。
  2. 随着工人数量的增加,您的有效批量增加,因此您需要降低学习率,至少在使用“香草”随机梯度下降时。具有自适应学习率的优化器(如Adagrad,Adam等)往往对初始学习率很稳健,但如果添加足够的工作者,您仍可能需要调整学习率。
  3. 您可能想知道为什么TensorFlow以这种方式处理训练数据。这是因为在分布式系统中,您不能依赖于速度相同的机器,甚至根本不可靠。如果您将训练数据划分为不相交的集合,这些集合将发送给每个工作人员,然后一台或多台计算机相对于另一台计算机速度较慢,或者网络发生故障等等,您的培训过程将看到来自“快速”的数据/可靠的工人比“慢”/不可靠的工人更频繁。这会使结果偏向那些情况(或者在极端情况下,将它们全部忽略)。

答案 1 :(得分:2)

有两种类型的服务器:

  • 工人:执行图表计算
  • 参数server:存储参数,以便所有工作人员可以共享和更新参数

您可以在参数服务器中设置瓶颈。如果模型非常大并且您的参数服务器很少,则需要在工作服务器和参数服务器之间进行更多通信。例如,如果您有2个参数服务器,则您将在一个服务器中拥有该模型的一半参数,而在另一个服务器中拥有另一半参数。如果你有很多工人,他们必须得到并向不同的工人发送大量参数,工人将有很大的滞后。如果增加参数服务器的数量,滞后将会减少,因为每个参数服务器和工作人员之间的通信较少。

由于您的批量大小为128(非常大),并且您只能在CPU中执行每秒8步,我认为您的模型速度太快,以至于共享服务器的参数比运行一次迭代需要更多时间。所以你也可以尝试增加批量。