How to add a new op in distributed TensorFlow?

Time: 2017-03-30 03:49:26

Tags: python tensorflow

I am trying to add a new op in distributed TensorFlow. The code below shows how I create a new operator that adds two arrays (following this tutorial: https://www.tensorflow.org/api_docs/python/tf/py_func).

import tensorflow as tf
import numpy as np

array1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
array2 = np.array([[5, 6], [7, 8]], dtype=np.float32)

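# The Python function wrapped by tf.py_func; it receives NumPy arrays.
def add(array1, array2):
    return array1 + array2

# Tout=[tf.float32]: the callback must return a float32 array.
add_op = tf.py_func(add, [array1, array2], [tf.float32])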


sess = tf.Session()
print sess.run(add_op)
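The callback itself is plain NumPy code, so it can be checked without TensorFlow. A minimal sketch: it calls the same `add` function directly and confirms that the result keeps the float32 dtype declared in `Tout=[tf.float32]`. A callback that returned a different dtype (for example float64) would be rejected when the op runs.

```python
import numpy as np

def add(array1, array2):
    # Element-wise sum; float32 + float32 stays float32.
    return array1 + array2

array1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
array2 = np.array([[5, 6], [7, 8]], dtype=np.float32)

result = add(array1, array2)
print(result.tolist())  # [[6.0, 8.0], [10.0, 12.0]]
print(result.dtype)     # float32, matching Tout=[tf.float32]
```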

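Then I followed this tutorial (https://www.tensorflow.org/deploy/distributed) to create a TensorFlow server cluster and tried to add the operator to it. I start the nodes with the following code and command lines.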

#worker.py
import tensorflow as tf

# Configuration of cluster 
ps_hosts = [ "127.0.0.1:8887" ]
worker_hosts = [ "127.0.0.1:8888", "127.0.0.1:8889" ]
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})      

tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

FLAGS = tf.app.flags.FLAGS

def main(_):
    # Start this process's server, then block and serve requests from the cluster.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)
    server.join()

if __name__ == "__main__":
    tf.app.run()

Command lines:

# On worker0
python worker.py --job_name=worker --task_index=0

# On ps0
python worker.py --job_name=ps --task_index=0
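In this setup the ClusterSpec maps each job name to a list of `host:port` addresses, and `--task_index` picks one entry from that list, so a device string such as `/job:worker/task:0` names one specific process. The plain-Python sketch below (no TensorFlow; the helper `task_address` is hypothetical, although `tf.train.ClusterSpec` has a method of the same name) resolves those addresses and lists which declared tasks the two commands above actually start.

```python
# Plain-Python model of the cluster layout in worker.py (no TensorFlow needed).
cluster = {
    "ps": ["127.0.0.1:8887"],
    "worker": ["127.0.0.1:8888", "127.0.0.1:8889"],
}

def task_address(cluster, job_name, task_index):
    """Hypothetical helper: the address that /job:<job_name>/task:<task_index> refers to."""
    return cluster[job_name][task_index]

# The two commands above start these (job_name, task_index) pairs.
started = [("worker", 0), ("ps", 0)]
started_addresses = {task_address(cluster, job, index) for job, index in started}

# Every address declared in the cluster, minus the ones actually launched.
all_addresses = {addr for hosts in cluster.values() for addr in hosts}
not_started = sorted(all_addresses - started_addresses)

print(task_address(cluster, "worker", 0))  # 127.0.0.1:8888
print(not_started)                         # ['127.0.0.1:8889']
```

Under this model, worker task 1 (127.0.0.1:8889) is declared but never launched; master.py only pins ops to worker task 0 and connects to the ps at 127.0.0.1:8887.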

Then I run this file to check whether the op runs.

#master.py

import tensorflow as tf
import numpy as np

def add(array1, array2):
    return array1 + array2    

with tf.device("/job:ps/task:0"):
    array1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
    array2 = np.array([[5, 6], [7, 8]], dtype=np.float32)

with tf.device("/job:worker/task:0"):
    add_op = tf.py_func(add, [array1, array2], [tf.float32])


with tf.Session("grpc://127.0.0.1:8887") as sess:
    print sess.run(add_op)

But this code does not work. Here is the error:

Traceback (most recent call last):
  File "master.py", line 19, in <module>
    print sess.run(add_op)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Failed to run py callback pyfunc_0: see error log.
     [[Node: PyFunc = PyFunc[Tin=[DT_FLOAT, DT_FLOAT], Tout=[DT_FLOAT], token="pyfunc_0", _device="/job:worker/replica:0/task:0/cpu:0"](PyFunc/input_0, PyFunc/input_1)]]
     [[Node: PyFunc_S1 = _Recv[client_terminated=false, recv_device="/job:ps/replica:0/task:0/cpu:0", send_device="/job:worker/replica:0/task:0/cpu:0", send_device_incarnation=-6409934414370243221, tensor_name="edge_8_PyFunc", tensor_type=DT_FLOAT, _device="/job:ps/replica:0/task:0/cpu:0"]()]]

Caused by op u'PyFunc', defined at:
  File "master.py", line 15, in <module>
    add_op = tf.py_func(add, [array1, array2], [tf.float32])
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/script_ops.py", line 189, in py_func
    input=inp, token=token, Tout=Tout, name=name)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_script_ops.py", line 40, in _py_func
    name=name)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__
    self._traceback = _extract_stack()

InternalError (see above for traceback): Failed to run py callback pyfunc_0: see error log.
     [[Node: PyFunc = PyFunc[Tin=[DT_FLOAT, DT_FLOAT], Tout=[DT_FLOAT], token="pyfunc_0", _device="/job:worker/replica:0/task:0/cpu:0"](PyFunc/input_0, PyFunc/input_1)]]
     [[Node: PyFunc_S1 = _Recv[client_terminated=false, recv_device="/job:ps/replica:0/task:0/cpu:0", send_device="/job:worker/replica:0/task:0/cpu:0", send_device_incarnation=-6409934414370243221, tensor_name="edge_8_PyFunc", tensor_type=DT_FLOAT, _device="/job:ps/replica:0/task:0/cpu:0"]()]]

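According to the N.B. on this page (https://www.tensorflow.org/api_docs/python/tf/py_func), I don't think I have broken any of the rules for the py_func function, such as serializing the model, or forgetting to run a tf.train.Server or to use tf.device(). Is there anything else I haven't noticed?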
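The PyFunc node in the error log refers to the Python callback only through a string token (`token="pyfunc_0"`), and the N.B. on the tf.py_func page states that the function body is not serialized into the GraphDef. As a toy model of that behaviour (hypothetical code, not TensorFlow's implementation), the sketch below keeps callbacks in an in-memory dict keyed by such tokens and shows that a second Python process, standing in for the server started by worker.py, cannot resolve a token registered in the first.

```python
import os
import shutil
import subprocess
import sys
import tempfile

# Hypothetical toy registry module: NOT TensorFlow's code, only the idea of
# keeping Python callbacks in process memory and referring to them by token.
REGISTRY_SRC = '''
_registry = {}

def register(func):
    token = "pyfunc_%d" % len(_registry)
    _registry[token] = func
    return token

def lookup(token):
    return _registry.get(token)
'''

tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, "toy_registry.py"), "w") as f:
    f.write(REGISTRY_SRC)
sys.path.insert(0, tmpdir)
import toy_registry  # noqa: E402

def add(array1, array2):
    return array1 + array2

# "master.py" side: register the callback; only the token string would go into a graph.
token = toy_registry.register(add)
found_here = toy_registry.lookup(token) is not None

# "worker.py" side: a separate interpreter imports the same module but starts
# with a fresh, empty registry, so the token cannot be resolved there.
child = subprocess.run(
    [sys.executable, "-c",
     "import sys, toy_registry; sys.exit(0 if toy_registry.lookup(%r) else 1)" % token],
    env=dict(os.environ, PYTHONPATH=tmpdir),
)
found_in_other_process = child.returncode == 0

shutil.rmtree(tmpdir)
print(token, found_here, found_in_other_process)  # pyfunc_0 True False
```

In this toy model, the token only resolves inside the process that registered it, which is the situation the N.B. addresses when it requires the op to run in the same address space as the program that calls tf.py_func().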

0 answers