当我们想要使用分布式TensorFlow时,我们将使用
创建一个参数服务器tf.train.Server.join()
但是,除了终止处理之外,我无法找到关闭服务器的任何方法。 join()的TensorFlow文档是
Blocks until the server has shut down.
This method currently blocks forever.
这对我来说非常麻烦,因为我想创建许多服务器进行计算,并在一切结束时关闭它们。
是否有可能的解决方案。
由于
答案 0 :(得分:12)
您可以使用session.run(dequeue_op)
而不是server.join()
按需使参数服务器进程死亡,并且当您希望此进程死亡时,让另一个进程将某些内容排入该队列。
因此,对于k
参数服务器分片,您可以创建具有唯一k
属性的shared_name
队列,并尝试从该队列中dequeue
。如果要关闭服务器,可以遍历所有队列并将enqueue
标记放到每个队列上。这将导致session.run
解除阻塞,Python进程将运行到最后并退出,从而关闭服务器。
下面是一个自包含的示例,其中包含2个分片: https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e
(对于多工作者/多个分片示例,请参阅https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca)
import subprocess
import tensorflow as tf
import time
import sys
flags = tf.flags
flags.DEFINE_string("port1", "12222", "port of worker1")
flags.DEFINE_string("port2", "12223", "port of worker2")
flags.DEFINE_string("task", "", "internal use")
FLAGS = flags.FLAGS
# setup local cluster from flags
host = "127.0.0.1:"
cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]}
clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()
if __name__=='__main__':
if not FLAGS.task: # start servers and run client
# launch distributed service
def runcmd(cmd): subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT)
runcmd("python %s --task=0"%(sys.argv[0]))
runcmd("python %s --task=1"%(sys.argv[0]))
time.sleep(1)
# bring down distributed service
sess = tf.Session("grpc://"+host+FLAGS.port1)
queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")
with tf.device("/job:worker/task:0"):
add_op0 = tf.add(tf.ones(()), tf.ones(()))
with tf.device("/job:worker/task:1"):
add_op1 = tf.add(tf.ones(()), tf.ones(()))
print("Running computation on server 0")
print(sess.run(add_op0))
print("Running computation on server 1")
print(sess.run(add_op1))
print("Bringing down server 0")
sess.run(queue0.enqueue(1))
print("Bringing down server 1")
sess.run(queue1.enqueue(1))
else: # Launch TensorFlow server
server = tf.train.Server(clusterspec, config=None,
job_name="worker",
task_index=int(FLAGS.task))
print("Starting server "+FLAGS.task)
sess = tf.Session(server.target)
queue = tf.FIFOQueue(1, tf.int32, shared_name="queue"+FLAGS.task)
sess.run(queue.dequeue())
print("Terminating server"+FLAGS.task)
答案 1 :(得分:3)
目前还没有关闭TensorFlow gRPC服务器的干净方法。 可能shut down a gRPC server,但安全地执行此操作需要对所有正在进行的请求和响应缓冲区进行额外的内存管理,这需要大量额外的管道(最糟糕的类型) :异步共享内存管理...)一个没人请求的功能 - 直到现在!
在实践中,您应该能够将相同的tf.train.Server
对象用于许多不同的计算。如果这对您的用例不起作用,请随时open an GitHub issue并告诉我们有关您的用例的更多信息。
答案 2 :(得分:2)
此页面经常出现在Google上,因此我认为我会尝试改进Yaroslav's answer,为那些刚刚进入分布式Tensorflow的人提供一个更明确的答案。
df
通过使用此代码段替换代码的worker部分来扩展“规范”Distributed Tensorflow example非常简单:
df
category date rate
1 2011-01-01 0.50
2 2011-01-01 0.75
1 2011-02-01 0.50
2 2011-02-01 0.75
1 2011-03-01 1.00
2 2011-03-01 1.25
1 2011-04-01 1.00
2 2011-04-01 1.25
请注意,MonlatedTrainingSession版本在将所有工作人员连接在一起时似乎要慢得多。