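I am creating a TensorFlow graph to run on a single node. But if I later want to train the same model graph in a distributed setting (partitioning the variables across multiple parameter servers and replicating the graph across n workers), how do I do that?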
I found tf.Graph.as_graph_def(), which exports the graph as a GraphDef protocol buffer, and tf.import_graph_def(), which imports a GraphDef back into a graph. But this did not work.
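For reference, the GraphDef round trip described above can be sketched on its own, with no device placement involved. This is a minimal sketch assuming the TF1-style API through tensorflow.compat.v1 (TF 1.14+ or 2.x); it uses a cut-down version of the question's model. It shows that imported nodes get the default "import/" name prefix (which is why the error mentions import/varl), that names without ":N" in return_elements come back as Operations, and that a bare GraphDef carries no collections, so the imported variable is not registered as a tf.Variable.

```python
import tensorflow.compat.v1 as tf  # TF1-style API; assumes TF >= 1.14 or 2.x

tf.disable_v2_behavior()

# Build a small single-node graph, modeled on the question's code.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=[], name="xin")
    l = tf.Variable(20.0, name="varl")
    X = tf.add(l, x, name="addop")

# Export the graph as a GraphDef protocol buffer.
graph_def = graph.as_graph_def()

# Import it into a fresh graph. Names without ":N" return Operations.
new_graph = tf.Graph()
with new_graph.as_default():
    xin_op, addop_op = tf.import_graph_def(graph_def, return_elements=["xin", "addop"])
    # Tensors are addressed as "<op name>:<output index>".
    addop_tensor = new_graph.get_tensor_by_name("import/addop:0")
    # A plain GraphDef has no collections, so no tf.Variable objects are recreated.
    imported_vars = tf.global_variables()

print(xin_op.name, addop_op.name, addop_tensor.name)
# prints: import/xin import/addop import/addop:0
```

If the variable collections are needed after import, tf.train.import_meta_graph works with a MetaGraphDef, which stores them alongside the GraphDef.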
Code:
import tensorflow as tf

# FLAGS, cluster and server are defined elsewhere (not shown in the question).

# Build the model once, as a single-node graph.
graph = tf.Graph()
with graph.as_default():
    x_place_holder = tf.placeholder(dtype=tf.float32, shape=[], name="xin")
    y_place_holder = tf.placeholder(dtype=tf.float32, shape=[], name="yin")
    m = tf.Variable(10.0, name="varm")
    l = tf.Variable(20.0, name="varl")
    Y = tf.multiply(m, x_place_holder, name="mulop")
    X = tf.add(l, x_place_holder, name="addop")
    cost = tf.abs(Y - X, name="cost")
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5, name="optimizer").minimize(cost)

tf.reset_default_graph()

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    print(FLAGS.task_index, "task index")
    # Re-import the single-node GraphDef under replica_device_setter.
    # Running the imported graph in a session later fails with the error below.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        tf.import_graph_def(graph.as_graph_def(),
                            return_elements=["xin", "yin", "varm", "varl", "mulop", "addop", "cost", "optimizer"])
Stack trace:
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1039, in _do_call
    return fn(*args)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1017, in _run_fn
    self._extend_graph()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1066, in _extend_graph
    self._session, graph_def.SerializeToString(), status)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot colocate nodes 'import/varl/read' and 'import/varl: Cannot merge devices with incompatible jobs: '/job:ps/task:1' and '/job:worker/task:1'
	 [[Node: import/varl/read = Identity[T=DT_FLOAT, _class=["loc:@import/varl"], _device="/job:worker/task:1"](import/varl)]]
Or is there another way to do this in TensorFlow?
Answer:
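As of June 2017, this is not supported. As the error shows, when the imported GraphDef is placed by the device setter, the variable import/varl is assigned to /job:ps/task:1 while its colocated read op import/varl/read is assigned to /job:worker/task:1, and TensorFlow cannot reconcile the two. To train the model in a distributed environment, reuse the Python code that builds the graph, wrapped in a replica_device_setter block, rather than the generated graph itself.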
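A minimal sketch of that approach, assuming the TF1-style API through tensorflow.compat.v1 (TF 1.14+ or 2.x). The helper names build_model, build_single_node_graph and build_distributed_graph, and the localhost cluster of 2 parameter servers and 2 workers, are hypothetical. The model-building code becomes a plain Python function: it is called directly for single-node training and inside tf.device(tf.train.replica_device_setter(...)) for distributed training. Only graph construction is shown; actually running it would also need a tf.train.Server per task and a session connected to it.

```python
import tensorflow.compat.v1 as tf  # TF1-style API; assumes TF >= 1.14 or 2.x

tf.disable_v2_behavior()


def build_model():
    """Hypothetical helper: builds the question's model in the current default graph."""
    x = tf.placeholder(dtype=tf.float32, shape=[], name="xin")
    y = tf.placeholder(dtype=tf.float32, shape=[], name="yin")
    m = tf.Variable(10.0, name="varm")
    l = tf.Variable(20.0, name="varl")
    Y = tf.multiply(m, x, name="mulop")
    X = tf.add(l, x, name="addop")
    cost = tf.abs(Y - X, name="cost")
    train_op = tf.train.GradientDescentOptimizer(
        learning_rate=0.5, name="optimizer").minimize(cost)
    return {"x": x, "y": y, "m": m, "l": l, "cost": cost, "train_op": train_op}


def build_single_node_graph():
    """Single-node use: call the builder with no device placement."""
    graph = tf.Graph()
    with graph.as_default():
        model = build_model()
    return graph, model


def build_distributed_graph(cluster, task_index):
    """Distributed use: the same builder, wrapped in replica_device_setter."""
    graph = tf.Graph()
    with graph.as_default():
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            model = build_model()
    return graph, model


# Hypothetical cluster: 2 parameter servers and 2 workers on localhost.
cluster = tf.train.ClusterSpec({
    "ps": ["localhost:2222", "localhost:2223"],
    "worker": ["localhost:2224", "localhost:2225"],
})

single_graph, single_model = build_single_node_graph()
dist_graph, dist_model = build_distributed_graph(cluster, task_index=1)

# Show where the device setter placed the variables and the cost op.
for key in ("m", "l", "cost"):
    print(key, repr(dist_model[key].device))
```

With the default round-robin strategy, varm lands on /job:ps/task:0 and varl on /job:ps/task:1, while cost stays on /job:worker/task:1. Because each variable's read op is created colocated with its variable during construction, the device setter does not give it a separate worker device, which appears to be what goes wrong with the imported GraphDef.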