tf.contrib.data.Dataset, tf.py_func, and distributed training in gcloud

Time: 2017-09-18 23:42:53

Tags: tensorflow dataset gcloud

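I have the following code that I want to run in gcloud. It runs locally without problems on CPU with TF 1.2 (I cannot test it on GPU with TF < 1.3, but it works with TF 1.3):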

...
device_fn = tf.train.replica_device_setter(
    ps_device='/job:ps',
    worker_device='/job:{}/task:{}'.format(job_name, index),
    cluster=tf.train.ClusterSpec({'ps': ps, 'worker': worker, 'master': master}))
with tf.Graph().as_default():
    with tf.device(device_fn):
        dataset = Dataset(....)
        dataset = dataset.map(my_py_fun)
        next_tensor = dataset.make_one_shot_iterator().get_next()
        create_model(next_tensor)
    with tf.train.MonitoredTrainingSession(...
...

When I run it in gcloud, I get this error:

...
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'dataset/IteratorGetNext': Operation was explicitly assigned to /job:worker/task:0 but available devices are [ /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/gpu:0 ]. Make sure the device specification refers to a valid device.
    [[Node: dataset/IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_STRING], _device="/job:worker/task:0"](dataset/OneShotIterator)]]
...

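I figured the dataset should live only on the local device, so I moved the dataset code before tf.device(device_fn), so that this part of the graph is created only on the local machine (each machine reads its own data). Like this: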

...
device_fn = tf.train.replica_device_setter(
    ps_device='/job:ps',
    worker_device='/job:{}/task:{}'.format(job_name, index),
    cluster=tf.train.ClusterSpec({'ps': ps, 'worker': worker, 'master': master}))
with tf.Graph().as_default():
    dataset = Dataset(....)
    dataset = dataset.map(my_py_fun)
    next_tensor = dataset.make_one_shot_iterator().get_next()
    with tf.device(device_fn):
        create_model(next_tensor)
    with tf.train.MonitoredTrainingSession(...
...

Then I get this error (note how IteratorGetNext is assigned to /job:ps/replica:0/task:0/cpu:0):

....
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _do_call
     raise type(e)(node_def, op, message)
UnknownError: exceptions.KeyError: 'pyfunc_0'
     [[Node: PyFunc = PyFunc[Tin=[DT_STRING], Tout=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT, DT_FLOAT, DT_FLOAT], token="pyfunc_0", _device="/device:CPU:*"](arg0)]]
     [[Node: dataset/IteratorGetNext_1 = IteratorGetNext[output_shapes=[[-1,3,299,299,3], [-1,3,18333]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:ps/replica:0/task:0/cpu:0"](dataset/OneShotIterator_1)]]
....
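
The KeyError: 'pyfunc_0' in this trace is consistent with a documented constraint of tf.py_func: the op must execute in the same address space as the Python program that called tf.py_func, because the Python callable is kept in an in-process registry and looked up by its token. The sketch below is a plain-Python model of that lookup, not TensorFlow code. FakeProcess, register and run_op are hypothetical names used only for illustration, and the real registry is more involved.

```python
# Simplified model of how a py_func-style op finds its Python callable.
# This is NOT TensorFlow code: FakeProcess and its methods are hypothetical.

class FakeProcess(object):
    """Stands in for one Python process (e.g. a worker task or a ps task)."""

    def __init__(self, name):
        self.name = name
        self._registry = {}   # token -> callable, private to this process
        self._next_id = 0

    def register(self, func):
        """Roughly what wrapping a function does: store it under a fresh token."""
        token = "pyfunc_{}".format(self._next_id)
        self._next_id += 1
        self._registry[token] = func
        return token

    def run_op(self, token, *args):
        """Roughly what executing the op does: look the token up locally."""
        return self._registry[token](*args)  # KeyError if the token is unknown here


worker = FakeProcess("worker")
ps = FakeProcess("ps")

# The graph is built in the worker process, so the callable is registered there.
token = worker.register(lambda path: path.upper())

# Running the op in the process that registered the callable works.
same_process_result = worker.run_op(token, "a.jpg")

# Running it in a different process fails: that process's registry is empty.
try:
    ps.run_op(token, "a.jpg")
    cross_process_error = None
except KeyError as err:
    cross_process_error = str(err)

print(same_process_result)
print(cross_process_error)
```

If this model matches what happens here, the PyFunc op created by dataset.map(my_py_fun) would need to run on a device owned by the process that built the graph, rather than on /job:ps, where the trace shows the enclosing IteratorGetNext was placed.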

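Is this a problem with py_func when used with GPU on TF 1.2, or a problem with how I set the devices? Is there any workaround to make it work with TF 1.2 (the latest version supported by gcloud)?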

0 answers:

No answers yet