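I have the following code that I run in gcloud. It works locally without problems using the CPU with TF 1.2 (I couldn't test the GPU with TF < 1.3, but it works with TF 1.3):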
...
device_fn = tf.train.replica_device_setter(
    ps_device='/job:ps',
    worker_device='/job:{}/task:{}'.format(job_name, index),
    cluster=tf.train.ClusterSpec({'ps': ps, 'worker': worker, 'master': master}))
with tf.Graph().as_default():
    with tf.device(device_fn):
        dataset = Dataset(....)
        dataset = dataset.map(my_py_fun)
        next_tensor = dataset.make_one_shot_iterator().get_next()
        create_model(next_tensor)
    with tf.train.MonitoredTrainingSession(...
...
When I run it in gcloud, I get this error:
...
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'dataset/IteratorGetNext': Operation was explicitly assigned to /job:worker/task:0 but available devices are [ /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/gpu:0 ]. Make sure the device specification refers to a valid device.
[[Node: dataset/IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_STRING], _device="/job:worker/task:0"](dataset/OneShotIterator)]]
...
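I figured the dataset should live only on the local device, so I moved the dataset code before tf.device(device_fn), so that this part of the graph is created only on the local machine (each machine will read the data). Like this: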
...
device_fn = tf.train.replica_device_setter(
    ps_device='/job:ps',
    worker_device='/job:{}/task:{}'.format(job_name, index),
    cluster=tf.train.ClusterSpec({'ps': ps, 'worker': worker, 'master': master}))
with tf.Graph().as_default():
    dataset = Dataset(....)
    dataset = dataset.map(my_py_fun)
    next_tensor = dataset.make_one_shot_iterator().get_next()
    with tf.device(device_fn):
        create_model(next_tensor)
    with tf.train.MonitoredTrainingSession(...
...
Then I get this error (note how IteratorGetNext is assigned to /job:ps/replica:0/task:0/cpu:0):
....
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _do_call
raise type(e)(node_def, op, message)
UnknownError: exceptions.KeyError: 'pyfunc_0'
[[Node: PyFunc = PyFunc[Tin=[DT_STRING], Tout=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT, DT_FLOAT, DT_FLOAT], token="pyfunc_0", _device="/device:CPU:*"](arg0)]]
[[Node: dataset/IteratorGetNext_1 = IteratorGetNext[output_shapes=[[-1,3,299,299,3], [-1,3,18333]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:ps/replica:0/task:0/cpu:0"](dataset/OneShotIterator_1)]]
....
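One possible reading of the KeyError above, offered as an assumption rather than a confirmed diagnosis: the tf.py_func documentation says the op must run in the same address space as the Python program that calls tf.py_func(). The op stores only a string token (here 'pyfunc_0'), while the callable itself stays in a registry inside the calling process. If the iterator ends up on /job:ps, as the node's _device suggests, the ps process would have no entry for that token. The toy sketch below imitates that lookup in plain Python. ToyRegistry and run_pyfunc_node are hypothetical names for illustration and do not use TensorFlow.

```python
# Toy model of a process-local function registry, loosely imitating how a
# py_func-style op refers to a Python callable by a string token.
# ToyRegistry and run_pyfunc_node are hypothetical, not TensorFlow APIs.


class ToyRegistry(object):
    """Maps tokens such as 'pyfunc_0' to callables within one process."""

    def __init__(self):
        self._funcs = {}

    def register(self, func):
        # Tokens are numbered in registration order: pyfunc_0, pyfunc_1, ...
        token = 'pyfunc_{}'.format(len(self._funcs))
        self._funcs[token] = func
        return token

    def call(self, token, *args):
        # Raises KeyError if this process never registered the token.
        return self._funcs[token](*args)


def run_pyfunc_node(registry, token, value):
    """Stands in for executing the node on the process that owns `registry`."""
    return registry.call(token, value)


# The worker process builds the graph, so only its registry knows the function.
worker_registry = ToyRegistry()
ps_registry = ToyRegistry()  # models a separate process: nothing registered

token = worker_registry.register(lambda s: s.upper())

# Running the node in the process that registered the function works.
result_on_worker = run_pyfunc_node(worker_registry, token, 'a.jpg')

# Running the same node in the other process fails on the token lookup.
try:
    run_pyfunc_node(ps_registry, token, 'a.jpg')
    error_on_ps = None
except KeyError as exc:
    error_on_ps = exc
```

If this reading holds, the dataset and its map function would need to be pinned to a device served by the same process that calls tf.py_func(), rather than left for the device setter or placer to assign.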
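Is this a problem with py_func, with the GPU on TF 1.2, or with how I set the devices? Is there any workaround to make it work with TF 1.2 (the latest version supported by gcloud)?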