在数据集解析器函数中加载NumPy数组

时间:2018-08-15 00:53:47

标签: tensorflow

我正在使用TensorFlow数据集来使用硬盘中的数据。数据存储在NumPy数组中,而NumPy数组的路径存储在文本文件中。创建数据集时,我正在使用dataset.map()函数将每个路径映射到NumPy数组。

这是我代码的相关部分:

def parser(path):
    x = np.load(path)
    return x

paths = ['data1.npy', 'data2.npy', 'data3.npy', 'data4.npy', ... ]

dataset = tf.data.Dataset.from_tensor_slices((paths))
dataset = dataset.map(map_func=parser)

但是,这会出现以下错误:

AttributeError: 'Tensor' object has no attribute 'read'

该错误涉及到行x = np.load(path)。看来我无法以这种方式在解析器函数中加载NumPy数组,因为path实际上不是字符串,而是张量。

正确的方法是什么?我想尽可能避免使用TFRecords。


我还尝试了如下包装load函数:

x = tf.py_func(np.load(path))

但这在那条线上给了我同样的错误:

AttributeError: 'Tensor' object has no attribute 'read'

1 个答案:

答案 0 :(得分:0)

您会收到此错误,因为np.load需要输入字符串,但会得到Tensor。 您可以使用tf.py_func包装加载函数。