TensorFlow DataSet API导致图形大小爆炸

时间:2017-09-13 21:38:27

标签: tensorflow

我有一个非常错误的数据集用于培训。

我正在使用数据集API,如下所示:

self._dataset = tf.contrib.data.Dataset.from_tensor_slices((self._images_list, self._labels_list))

self._dataset = self._dataset.map(self.load_image)

self._dataset = self._dataset.batch(batch_size)
self._dataset = self._dataset.shuffle(buffer_size=shuffle_buffer_size)
self._dataset = self._dataset.repeat()

self._iterator = self._dataset.make_one_shot_iterator()

如果我用于培训,那么一小部分数据都很好。 如果我使用我的所有数据,那么TensorFlow将崩溃并出现此错误: ValueError:GraphDef不能大于2GB。

似乎TensorFlow尝试加​​载所有数据而不是仅加载它需要的数据......不确定......

任何建议都会很棒!

更新...找到解决方案/解决方法

根据这篇文章:Tensorflow Dataset API doubles graph protobuff filesize

我用make_initializable_iterator()替换了make_one_shot_iterator(),当然在创建会话后调用了iterator初始化器:

init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_data._iterator.initializer)

但是我对这个问题保持开放似乎是一种解决方法而不是解决方案...

1 个答案:

答案 0 :(得分:2)

https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays

  

请注意,以上代码段将把特征和标签数组作为tf.constant()操作嵌入到TensorFlow图中。这对于一个小的数据集来说效果很好,但是浪费了内存-因为数组的内容将被多次复制-并可能达到tf.GraphDef协议缓冲区的2GB限制。     或者,您可以根据tf.placeholder()张量定义数据集,并在对数据集初始化Iterator时提供NumPy数组。

代替使用

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

使用

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))