所以我想使用数据集API来批量处理大型数据集(约8GB),因为在使用GPU时,由于我正在使用feed_dict将数据从python传递到Tensorflow上,因此存在大量空闲时间。
当我按照此处提到的教程进行操作时:
运行我的简单代码时:
one_hot_dataset = np.load("one_hot_dataset.npy")
dataset = tf.data.Dataset.from_tensor_slices(one_hot_dataset)
我在TensorFlow 1.8和Python 3.5中收到错误消息:
Traceback (most recent call last):
File "<ipython-input-17-412a606c772f>", line 1, in <module>
dataset = tf.data.Dataset.from_tensor_slices((one_hot_dataset))
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 235, in from_tensor_slices
return TensorSliceDataset(tensors)
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in __init__
for i, t in enumerate(nest.flatten(tensors))
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in <listcomp>
for i, t in enumerate(nest.flatten(tensors))
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1014, in convert_to_tensor
as_ref=False)
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 496, in make_tensor_proto
"Cannot create a tensor proto whose content is larger than 2GB.")
ValueError: Cannot create a tensor proto whose content is larger than 2GB.
我该如何解决?我认为原因很明显,但是tf开发人员通过将输入数据限制为2GB怎么了?我真的无法理解这种合理性,在处理较大的数据集时该如何解决?
我用Google搜索了很多,但是找不到类似的错误消息。当我使用numpy数据集的FITFH时,以上步骤可以正常工作。
我某种程度上需要告诉TensorFlow我实际上将逐批加载数据,并且可能想预取一些批处理以保持我的GPU繁忙。但是似乎正在尝试一次加载整个numpy数据集。那么使用Dataset API有什么好处,因为我能够通过简单地尝试将numpy数据集作为tf.constant加载到TensorFlow图中来重现此错误,这显然不适合并且出现OOM错误。 / p>
感谢技巧和故障排除提示!
答案 0 :(得分:2)
tf.data
用户指南(https://www.tensorflow.org/guide/datasets)“使用NumPy数组”部分中解决了此问题。
基本上,创建一个dataset.make_initializable_iterator()
迭代器并在运行时提供数据。
如果由于某种原因该方法不起作用,则可以将数据写入文件或从Python生成器(https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator)创建数据集,在其中可以放置任意Python代码,包括切片numpy数组并产生切片