分布式训练宽而浅的模型

时间:2017-08-22 08:48:13

标签: tensorflow google-cloud-ml

我正在开发一个非常宽而浅的计算图,在一台机器上具有相对较少的共享参数。我想让图表更宽,但内存不足。我的理解是,通过使用Distributed Tensorflow,可以使用tf.device上下文管理器在worker之间拆分图。然而,目前尚不清楚如何处理损失,这只能通过运行整个图表和训练操作来计算。

为这种模型训练参数的正确策略是什么?

1 个答案:

答案 0 :(得分:1)

TensorFlow基于数据流图的概念。您可以定义由变量和操作组成的图形,并且可以将所述变量和操作放在不同的服务器和/或设备上。当您调用session.Run时,您将数据传递到图表以及输入(在feed_dict中指定)和输出(在fetches参数中指定的{{1})之间的每个操作运行,无论这些操作位于何处。当然,跨服务器传递数据会产生通信开销,但这种开销通常是由于您可以让多个并发工作者同时执行计算这一事实。

简而言之,即使您将操作放在其他服务器上,您仍然可以计算完整图表上的损失。

这是一个关于大规模线性模型的教程:https://www.tensorflow.org/tutorials/linear

这是TensorFlow中分布式培训的教程: https://www.tensorflow.org/deploy/distributed