Tensorflow数据集API-.from_tensor_slices()/ .from_tensor()-无法创建内容大于2gb的张量原型

时间:2018-06-30 21:17:45

标签: python python-3.x tensorflow pipeline tensorflow-datasets

所以我想使用数据集API来批量处理大型数据集(约8GB),因为在使用GPU时,由于我正在使用feed_dict将数据从python传递到Tensorflow上,因此存在大量空闲时间。

当我按照此处提到的教程进行操作时:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/5_DataManagement/tensorflow_dataset_api.py

运行我的简单代码时:

one_hot_dataset = np.load("one_hot_dataset.npy")
dataset = tf.data.Dataset.from_tensor_slices(one_hot_dataset)

我在TensorFlow 1.8和Python 3.5中收到错误消息:

Traceback (most recent call last):

  File "<ipython-input-17-412a606c772f>", line 1, in <module>
    dataset = tf.data.Dataset.from_tensor_slices((one_hot_dataset))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 235, in from_tensor_slices
    return TensorSliceDataset(tensors)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in __init__
    for i, t in enumerate(nest.flatten(tensors))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in <listcomp>
    for i, t in enumerate(nest.flatten(tensors))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1014, in convert_to_tensor
    as_ref=False)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 496, in make_tensor_proto
    "Cannot create a tensor proto whose content is larger than 2GB.")

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

我该如何解决?我认为原因很明显,但是tf开发人员通过将输入数据限制为2GB怎么了?我真的无法理解这种合理性,在处理较大的数据集时该如何解决?

我用Google搜索了很多,但是找不到类似的错误消息。当我使用numpy数据集的FITFH时,以上步骤可以正常工作。

我某种程度上需要告诉TensorFlow我实际上将逐批加载数据,并且可能想预取一些批处理以保持我的GPU繁忙。但是似乎正在尝试一次加载整个numpy数据集。那么使用Dataset API有什么好处,因为我能够通过简单地尝试将numpy数据集作为tf.constant加载到TensorFlow图中来重现此错误,这显然不适合并且出现OOM错误。 / p>

感谢技巧和故障排除提示!

1 个答案:

答案 0 :(得分:2)

tf.data用户指南(https://www.tensorflow.org/guide/datasets)“使用NumPy数组”部分中解决了此问题。

基本上,创建一个dataset.make_initializable_iterator()迭代器并在运行时提供数据。

如果由于某种原因该方法不起作用,则可以将数据写入文件或从Python生成器(https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator)创建数据集,在其中可以放置任意Python代码,包括切片numpy数组并产生切片