Tensorflow Dataset.from_tensor_slices耗时太长

时间:2017-10-20 19:24:03

标签: python numpy tensorflow tensorflow-datasets

我有以下代码:

data = np.load("data.npy")
print(data) # Makes sure the array gets loaded in memory
dataset = tf.contrib.data.Dataset.from_tensor_slices((data))

文件"data.npy"为3.3 GB。使用numpy读取文件需要几秒钟,但是创建tensorflow数据集对象的下一行需要花费很长时间才能执行。这是为什么?引擎盖下做了什么?

2 个答案:

答案 0 :(得分:4)

引用此answer

  Dataset

np.load只返回文件加载器,而不是实际数据。它是一个“懒惰的加载器”,只在访问时才加载特定的数组。

这就是为什么它很快。

修改1:以扩展此答案,另一个来自tensorflow's documentation的引文:

  

如果您的所有输入数据都适合内存,那么从中创建tf.Tensor的最简单方法是将它们转换为Dataset.from_tensor_slices()个对象并使用{{1}}。

     

这适用于小型数据集,但浪费内存---因为数组的内容将被多次复制---并且可以达到tf.GraphDef协议缓冲区的2GB限制。

该链接还显示了如何有效地完成这项工作。

答案 1 :(得分:0)

尝试:

data = np.load("data.npy")
a = tf.placeholder(tf.float32, shape)
dataset = tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.prefetch(buffer_size=1000)
dataset = dataset.batch(128)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={a: data})

处理大型数据集时,tf.placeholder更好。