我有以下代码:
data = np.load("data.npy")
print(data) # Makes sure the array gets loaded in memory
dataset = tf.contrib.data.Dataset.from_tensor_slices((data))
文件"data.npy"
为3.3 GB。使用numpy读取文件需要几秒钟,但是创建tensorflow数据集对象的下一行需要花费很长时间才能执行。这是为什么?引擎盖下做了什么?
答案 0 :(得分:4)
引用此answer:
Dataset
的
np.load
只返回文件加载器,而不是实际数据。它是一个“懒惰的加载器”,只在访问时才加载特定的数组。
这就是为什么它很快。
修改1:以扩展此答案,另一个来自tensorflow's documentation的引文:
如果您的所有输入数据都适合内存,那么从中创建
tf.Tensor
的最简单方法是将它们转换为Dataset.from_tensor_slices()
个对象并使用{{1}}。这适用于小型数据集,但浪费内存---因为数组的内容将被多次复制---并且可以达到tf.GraphDef协议缓冲区的2GB限制。
该链接还显示了如何有效地完成这项工作。
答案 1 :(得分:0)
尝试:
data = np.load("data.npy")
a = tf.placeholder(tf.float32, shape)
dataset = tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.prefetch(buffer_size=1000)
dataset = dataset.batch(128)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer, feed_dict={a: data})
处理大型数据集时,tf.placeholder
更好。