tf.data.Dataset.from_generator does not support uint16, uint32?

Time: 2018-03-22 19:45:46

Tags: python tensorflow tensorflow-datasets

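I have an HDF5 file containing a uint16 dataset, and I want to read the data with TensorFlow's Dataset API. I'm using from_generator to produce items from the HDF5 file on demand. However, I get an error about an unsupported data type.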
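A minimal sketch of the kind of callable generator described here, assuming the HDF5 data is opened with h5py (whose Dataset objects support len() and integer indexing like a NumPy array). To keep the sketch self-contained it accepts any such array-like source; the class name RowGenerator is hypothetical.

```python
import numpy as np

class RowGenerator:
    """Yield rows of an array-like source one at a time, on demand.

    `source` is assumed to support len() and integer indexing, which holds
    for NumPy arrays and for h5py Dataset objects.
    """

    def __init__(self, source):
        self.source = source

    def __call__(self):
        # One index per step: for an h5py Dataset this reads just that row.
        for i in range(len(self.source)):
            yield self.source[i]

# Stand-in for an HDF5 dataset stored as uint16.
data = np.arange(12, dtype=np.uint16).reshape(4, 3)
rows = list(RowGenerator(data)())
print(len(rows), rows[0].dtype)  # 4 uint16
```

An instance would be passed as the first argument of tf.data.Dataset.from_generator. The rows keep their uint16 dtype, which is the dtype the question goes on to report as unsupported.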

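Edit: I originally thought the problem was related to how I was using HDF5, but it can be reproduced with the simpler case mentioned in the comments.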

The following code reproduces the problem:

```python
import tensorflow as tf
import numpy as np

class generator:
    def __init__(self):
        pass

    def __call__(self):
        # Yield a single random scalar cast to uint16
        yield np.random.random_integers(0,1000).astype(np.uint16)

ds = tf.data.Dataset.from_generator(
    generator(),           # callable returning an iterator
    tf.uint16,             # output_types
    tf.TensorShape(None))  # output_shapes (unknown shape)

next_element = ds.make_one_shot_iterator().get_next()

# TF 1.x graph mode: pull elements until the dataset is exhausted
with tf.Session() as sess:
    while True:
        try:
            my_data = sess.run(next_element)
            print(my_data)
        except tf.errors.OutOfRangeError:
            break
```

When it crashes, I get the following traceback:

```
2018-03-22 22:53:07.803119: W tensorflow/core/framework/op_kernel.cc:1190] Unimplemented: Unsupported numpy type 4
2018-03-22 22:53:07.803235: W tensorflow/core/framework/op_kernel.cc:1202] OP_REQUIRES failed at iterator_ops.cc:870 : Unimplemented: Unsupported numpy type 4
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_UINT16], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "datatype_problem.py", line 30, in <module>
    my_data = sess.run(next_element)
  File "/home/username/tensorflow-remote/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
  File "/home/username/tensorflow-remote/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1137, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/username/tensorflow-remote/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1355, in _do_run
    options, run_metadata)
  File "/home/username/tensorflow-remote/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1374, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported numpy type 4
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_UINT16], token="pyfunc_1"](arg0)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>], output_types=[DT_UINT16], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
```

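The "Unimplemented: Unsupported numpy type 4" part seems to indicate that the uint16 data type may not be supported (in NumPy's type numbering, 4 is NPY_USHORT, i.e. uint16).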
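NumPy identifies each dtype by an internal type number, exposed as dtype.num, and that is most likely the number in the error message. A quick NumPy check of the numbers for a few small integer types:

```python
import numpy as np

# dtype.num is NumPy's internal type number (the NPY_TYPES enum value),
# the kind of number a message like "Unsupported numpy type N" reports.
TYPE_NUMBERS = {name: np.dtype(name).num for name in ("uint8", "int16", "uint16")}

for name, num in TYPE_NUMBERS.items():
    print(name, num)
# uint8 2
# int16 3
# uint16 4
```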

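@KRish suggested a workaround in the comments that seems to work: change the output data type in the from_generator call, e.g. to tf.int64. This works, but it may require more type casts in the rest of the pipeline. They also mentioned that this code on github may explain the lack of support. Other unsigned types, such as uint32 and uint64, appear to be affected as well.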
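A minimal sketch of the workaround from the comments, assuming TensorFlow 2.4 or later with eager execution (rather than the TF 1.x Session code above): the generator yields int64 values, from_generator declares tf.int64, and a map step casts back to tf.uint16. The generator name gen_widened and the value range are illustrative.

```python
import numpy as np
import tensorflow as tf

def gen_widened():
    # Hypothetical generator: values lie in the uint16 range, but each one is
    # yielded as int64, the output type reported to work in the question.
    rng = np.random.default_rng(0)
    for _ in range(5):
        yield np.int64(rng.integers(0, 1001))

ds = tf.data.Dataset.from_generator(
    gen_widened,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int64),
)
# Narrow back to uint16 with a regular TensorFlow cast inside the pipeline.
ds = ds.map(lambda x: tf.cast(x, tf.uint16))

for item in ds:
    print(item.dtype, item.numpy())
```

With TF 1.x, pass tf.int64 as the second argument of from_generator instead of output_signature, and iterate with make_one_shot_iterator and a Session as in the question; the map and cast step is the same.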

Is there a reason these unsigned integer types are not supported? What is the recommended workaround for using the tf.uint16 type with Dataset.from_generator?
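The widening workaround extends to the other unsigned types: cast each value to a signed type that can represent every value of the unsigned type, then cast back inside the pipeline. A minimal NumPy sketch of that choice follows; the helper name widen is hypothetical. uint64 has no lossless signed counterpart (int64 cannot hold values of 2**63 and above), so the sketch raises an error for it instead of silently wrapping.

```python
import numpy as np

def widen(arr):
    """Return arr cast to a signed integer type that holds all its values."""
    if arr.dtype.kind != "u":
        return arr  # not an unsigned integer array: leave it unchanged
    if arr.dtype.itemsize >= 8:
        # int64 cannot represent uint64 values >= 2**63.
        raise TypeError("uint64 has no lossless signed integer counterpart")
    # Double the bit width: uint8 -> int16, uint16 -> int32, uint32 -> int64.
    return arr.astype(np.dtype(f"int{arr.dtype.itemsize * 16}"))

x = np.array([0, 1000, 65535], dtype=np.uint16)
w = widen(x)
print(w.dtype, w.tolist())                     # int32 [0, 1000, 65535]
print(np.array_equal(w.astype(np.uint16), x))  # True
```

For uint16 this picks int32, which halves per-element storage compared with int64 while still round-tripping exactly.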

0 answers:

No answers yet.