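py_func cannot handle a list with more than 9 items

Time: 2018-01-08 20:49:21

Tags: tensorflow tensorflow-datasets

I am trying to use Dataset.map to pass two values to a function, which then returns a list of values. When the returned list contains more than eight elements, py_func fails to handle the type. When the list returned by map_func contains eight elements, there is no problem.

The TensorFlow version is 1.4.1, running on the Trisquel distribution.

Successful case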

import tensorflow as tf

def gen_range(groups=4, limit=1000):
    jump = limit/groups
    start, stop = 0, 0
    while stop != limit:
        stop = start + jump
        yield start, stop
        start = stop

def bridge(x, y):
    return [[[x] * 4, [y] * 4]]

with tf.Session() as sess:
    dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
        lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2).\
        make_one_shot_iterator()
    init = tf.global_variables_initializer()
    sess.run(init)
    while True:
        print(sess.run(dataset.get_next()))
Output:
(array([[  0,   0,   0,   0],
   [250, 250, 250, 250]], dtype=int32),)
(array([[250, 250, 250, 250],
   [500, 500, 500, 500]], dtype=int32),)
(array([[500, 500, 500, 500],
   [750, 750, 750, 750]], dtype=int32),)
2018-01-09 01:56:12.871943: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: StopIteration: Iteration finished.
(array([[ 750,  750,  750,  750],
   [1000, 1000, 1000, 1000]], dtype=int32),)

Failing case

def bridge(x, y):
    return [[[x] * 5, [y] * 4]]

with tf.Session() as sess:
    dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
        lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2).\
        make_one_shot_iterator()
    init = tf.global_variables_initializer()
    sess.run(init)
    while True:
        print(sess.run(dataset.get_next()))
Output:
2018-01-09 01:56:41.201683: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
2018-01-09 01:56:41.201960: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
 [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
 2018-01-09 01:56:41.202002: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1322     try:
-> 1323       return fn(*args)
   1324     except errors.OpError as e:

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
 1301                                    feed_dict, fetch_list, target_list,
 -> 1302                                    status, run_metadata)
 1303 

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
 472             compat.as_text(c_api.TF_Message(self.status.status)),
 --> 473             c_api.TF_GetCode(self.status.status))
     474     # Delete the underlying status object from memory otherwise it stays alive

 UnimplementedError: Unsupported object type list
 [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
 [[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=
 [<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]

 During handling of the above exception, another exception occurred:

 UnimplementedError                        Traceback (most recent call last)
 <ipython-input-120-120ecc56d75d> in <module>()
       4     sess.run(init)
       5     while True:
 ----> 6         print(sess.run(dataset.get_next()))
       7 
       8 

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
     887     try:
     888       result = self._run(None, fetches, feed_dict, options_ptr,
 --> 889                          run_metadata_ptr)
     890       if run_metadata:
     891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
    1119       results = self._do_run(handle, final_targets, final_fetches,
 -> 1120                              feed_dict_tensor, options, run_metadata)
    1121     else:
    1122       results = []

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
 1315     if handle is None:
 1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
 -> 1317                            options, run_metadata)
    1318     else:
    1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
    1334         except KeyError:
    1335           pass
 -> 1336       raise type(e)(node_def, op, message)
    1337 
    1338   def _extend_graph(self):

 UnimplementedError: Unsupported object type list
     [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
     [[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=[<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]

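1 answer:

Answer 0 (score: 1)

The problem does not come from the size of the list. It comes from the fact that what you return cannot be converted to a Tensor.

When you return [[[x] * 4, [y] * 4]], it can be converted to a tensor of shape (1, 2, 4):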

x, y = 1, 2  # example values; any pair of integers behaves the same way
res = tf.constant([[[x] * 4, [y] * 4]])
print(res.get_shape())  # prints (1, 2, 4)

When you return [[[x] * 5, [y] * 4]], you get something like this for x=1, y=2:

[[[1, 1, 1, 1, 1],
  [2, 2, 2, 2]
]]

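This cannot be converted to a Tensor because the first row and the second row have different lengths. That is why eight elements (4 + 4) work while nine (5 + 4) fail: the rows stop matching, not the total count.

You can trigger a similar error directly by trying this: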
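To make the row-length argument concrete, here is a minimal pure-Python sketch of a shape check. The helper name infer_shape is hypothetical and this is not TensorFlow's actual conversion code; it only illustrates why a nested list needs matching sibling lengths before it can have a single shape.

```python
def infer_shape(value):
    """Return the shape of a nested list, or raise ValueError if it is ragged.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    if not isinstance(value, (list, tuple)):
        return ()  # a scalar has an empty shape
    if len(value) == 0:
        return (0,)
    child_shapes = [infer_shape(item) for item in value]
    first = child_shapes[0]
    for shape in child_shapes[1:]:
        if shape != first:
            raise ValueError(
                "ragged nesting: sibling shapes %r and %r differ" % (first, shape))
    return (len(value),) + first


x, y = 1, 2
print(infer_shape([[[x] * 4, [y] * 4]]))  # (1, 2, 4)
try:
    infer_shape([[[x] * 5, [y] * 4]])
except ValueError as err:
    print(err)  # ragged nesting: sibling shapes (5,) and (4,) differ
```

The rectangular case yields one well-defined shape, while the 5-and-4 case has no single shape to report, which mirrors the failure seen with py_func.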

res = tf.constant([[1, 2], [3]])
  

Argument must be a dense tensor: [[1, 2], [3]] - got shape [2], but wanted [2, 2].

TensorFlow cannot infer a shape for the tensor.
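If rows of different lengths are really needed, two possible workarounds are to pad the shorter row or to return each row as its own output. The sketch below is an assumption, not something the answer proposes: pad_rows, bridge_padded and bridge_separate are hypothetical names, the fill value 0 is arbitrary, and the code is plain Python so it can be checked without TensorFlow.

```python
def pad_rows(rows, fill=0):
    """Pad every row with `fill` so all rows share the longest row's length."""
    width = max(len(row) for row in rows)
    return [list(row) + [fill] * (width - len(row)) for row in rows]


def bridge_padded(x, y):
    # Same nesting as the question's bridge, but rectangular:
    # both inner rows now have length 5.
    return [pad_rows([[x] * 5, [y] * 4])]


def bridge_separate(x, y):
    # Two independent 1-D outputs; in the question's pipeline this would
    # presumably need Tout=[tf.int32, tf.int32] instead of [tf.int32].
    return [x] * 5, [y] * 4


print(bridge_padded(1, 2))    # [[[1, 1, 1, 1, 1], [2, 2, 2, 2, 0]]]
print(bridge_separate(1, 2))  # ([1, 1, 1, 1, 1], [2, 2, 2, 2])
```

Padding keeps a single output but changes the data, so the consumer has to know which values are filler. Separate outputs keep the data intact at the cost of a second output type in the py_func call.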