Tensorflow:在分布式培训中使用参数服务器

时间:2016-12-05 17:13:43

标签: python tensorflow

目前还不完全清楚参数服务器如何知道在分布式张量流训练中该做什么。

例如,在此SO question中,以下代码用于配置参数服务器和辅助任务:

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    ##some training code

server.join()如何表明给定的任务应该是参数服务器?参数是否为任务提供了一种默认行为?你还可以/应该告诉参数服务任务吗?

编辑:此SO question解决了我的一些疑问:"那里的逻辑确保将Variable对象均匀分配给充当参数服务器的工作人员。&#34 ;但参数服务器如何知道它是一个参数服务器? server.join()足够吗?

1 个答案:

答案 0 :(得分:10)

TL; DR: TensorFlow对“参数服务器”一无所知,但它支持在不同进程中跨多个设备运行图形。其中一些进程具有名称以"/job:ps"开头的设备,这些设备包含变量。工作人员推动了培训过程,当他们运行train_op时,他们将导致"/job:ps"设备上的工作发生,这将更新共享变量。

server.join()方法只是告诉TensorFlow阻塞并侦听请求,直到服务器关闭(这当前意味着它会永久阻塞,或直到你终止进程,因为当前没有实现干净关闭)。 / p>

在我之前回答的example中,PS任务是被动的,一切都由## some training code中的工作人员任务控制。如果将代码分割到多个设备上,TensorFlow将添加适当的通信,这将扩展到不同进程中的设备。 with tf.device(tf.train.replica_device_setter(...)):块告诉TensorFlow将每个变量放在不同的PS任务上,方法是将其设备设置为"/job:ps/task:{i}"(对于{i}的不同值,以循环方式选择)。

当您调用sess.run(train_op)时,TensorFlow将运行依赖于并更新变量的图表,并包含更新它们的操作。这部分计算将在"/job:ps"设备上进行,因此这些设备将充当参数服务器。