我正在尝试在分布式张量流中添加新操作。下面的代码是我如何创建一个可以添加两个数组的新运算符(遵循本教程:https://www.tensorflow.org/api_docs/python/tf/py_func)。
import tensorflow as tf
import numpy as np
array1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
array2 = np.array([[5, 6], [7, 8]], dtype=np.float32)
def add(array1, array2):
return array1 + array2
add_op = tf.py_func(add, [array1, array2], [tf.float32])
sess = tf.Session()
print sess.run(add_op)
然后我按照本教程(https://www.tensorflow.org/deploy/distributed)创建一个TensorFlow服务器集群,并尝试在其中添加一个运算符。我启动节点,使用以下代码和命令行。
#worker.py
import tensorflow as tf
# Configuration of cluster
ps_hosts = [ "127.0.0.1:8887" ]
worker_hosts = [ "127.0.0.1:8888", "127.0.0.1:8889" ]
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
FLAGS = tf.app.flags.FLAGS
def main(_):
server = tf.train.Server(cluster,
job_name=FLAGS.job_name,
task_index=FLAGS.task_index)
server.join()
if __name__ == "__main__":
tf.app.run()
命令行:
# On worker0
python worker.py --job_name=worker --task_index=0
# On ps0
python worker.py --job_name=ps --task_index=0
然后执行此文件以指定操作是否运行。
#master.py
import tensorflow as tf
import numpy as np
def add(array1, array2):
return array1 + array2
with tf.device("/job:ps/task:0"):
array1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
array2 = np.array([[5, 6], [7, 8]], dtype=np.float32)
with tf.device("/job:worker/task:0"):
add_op = tf.py_func(add, [array1, array2], [tf.float32])
with tf.Session("grpc://127.0.0.1:8887") as sess:
print sess.run(add_op)
但是这段代码不起作用,以下是错误
Traceback (most recent call last):
File "master.py", line 19, in <module>
print sess.run(add_op)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
run_metadata_ptr)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
feed_dict_string, options, run_metadata)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
target_list, options, run_metadata)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Failed to run py callback pyfunc_0: see error log.
[[Node: PyFunc = PyFunc[Tin=[DT_FLOAT, DT_FLOAT], Tout=[DT_FLOAT], token="pyfunc_0", _device="/job:worker/replica:0/task:0/cpu:0"](PyFunc/input_0, PyFunc/input_1)]]
[[Node: PyFunc_S1 = _Recv[client_terminated=false, recv_device="/job:ps/replica:0/task:0/cpu:0", send_device="/job:worker/replica:0/task:0/cpu:0", send_device_incarnation=-6409934414370243221, tensor_name="edge_8_PyFunc", tensor_type=DT_FLOAT, _device="/job:ps/replica:0/task:0/cpu:0"]()]]
Caused by op u'PyFunc', defined at:
File "master.py", line 15, in <module>
add_op = tf.py_func(add, [array1, array2], [tf.float32])
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/script_ops.py", line 189, in py_func
input=inp, token=token, Tout=Tout, name=name)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_script_ops.py", line 40, in _py_func
name=name)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
op_def=op_def)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/Users/WakeUp/Repository/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__
self._traceback = _extract_stack()
InternalError (see above for traceback): Failed to run py callback pyfunc_0: see error log.
[[Node: PyFunc = PyFunc[Tin=[DT_FLOAT, DT_FLOAT], Tout=[DT_FLOAT], token="pyfunc_0", _device="/job:worker/replica:0/task:0/cpu:0"](PyFunc/input_0, PyFunc/input_1)]]
[[Node: PyFunc_S1 = _Recv[client_terminated=false, recv_device="/job:ps/replica:0/task:0/cpu:0", send_device="/job:worker/replica:0/task:0/cpu:0", send_device_incarnation=-6409934414370243221, tensor_name="edge_8_PyFunc", tensor_type=DT_FLOAT, _device="/job:ps/replica:0/task:0/cpu:0"]()]]
根据N.B.在这个页面(https://www.tensorflow.org/api_docs/python/tf/py_func)中,我认为我没有违反py_func函数的任何规则。例如序列化模型或忘记运行tf.train.Server
和tf.device()
。还有什么我没注意到的吗?