我有一个非常错误的数据集用于培训。
我正在使用数据集API,如下所示:
self._dataset = tf.contrib.data.Dataset.from_tensor_slices((self._images_list, self._labels_list))
self._dataset = self._dataset.map(self.load_image)
self._dataset = self._dataset.batch(batch_size)
self._dataset = self._dataset.shuffle(buffer_size=shuffle_buffer_size)
self._dataset = self._dataset.repeat()
self._iterator = self._dataset.make_one_shot_iterator()
如果我用于培训,那么一小部分数据都很好。 如果我使用我的所有数据,那么TensorFlow将崩溃并出现此错误: ValueError:GraphDef不能大于2GB。
似乎TensorFlow尝试加载所有数据而不是仅加载它需要的数据......不确定......
任何建议都会很棒!
更新...找到解决方案/解决方法
根据这篇文章:Tensorflow Dataset API doubles graph protobuff filesize
我用make_initializable_iterator()替换了make_one_shot_iterator(),当然在创建会话后调用了iterator初始化器:
init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_data._iterator.initializer)
但是我对这个问题保持开放似乎是一种解决方法而不是解决方案...
答案 0 :(得分:2)
https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays
请注意,以上代码段将把特征和标签数组作为tf.constant()操作嵌入到TensorFlow图中。这对于一个小的数据集来说效果很好,但是浪费了内存-因为数组的内容将被多次复制-并可能达到tf.GraphDef协议缓冲区的2GB限制。 或者,您可以根据tf.placeholder()张量定义数据集,并在对数据集初始化Iterator时提供NumPy数组。
代替使用
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
使用
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))