TensorFlow存储在Graph中的整个Datatset

时间:2018-04-18 15:32:16

标签: python tensorflow

我正致力于使用Cifar-10数据集开发CNN并将数据提供给网络,我使用数据集API将可馈送迭代器与句柄占位符一起使用:https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator。我个人非常喜欢这种方法,因为它提供了一种清晰简单的方法来将数据提供给网络并在我的测试和验证集之间切换。但是,当我在训练结束时保存图形时,创建的.meta文件与我开始使用的测试数据一样大。我正在使用这些操作来提供对输入占位符和输出操作符的访问:

tf.get_collection("validation_nodes")
tf.add_to_collection("validation_nodes", input_data)
tf.add_to_collection("validation_nodes", input_labels)
tf.add_to_collection("validation_nodes", predict)

然后使用以下内容保存图表: 训练前:

saver = tf.train.Saver()

训练结束后:

save_path = saver.save(sess, "./my_model")

有没有办法阻止TensorFlow存储图表中的所有数据?提前谢谢!

1 个答案:

答案 0 :(得分:0)

您正在为数据集创建tf.constant,这就是为什么它已添加到图表定义中的原因。解决方案是使用可初始化的迭代器并定义占位符。在开始对图形运行操作之前,您要做的第一件事就是将数据集提供给它。请参阅"创建迭代器"下的程序员指南。部分为例。

https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator

我做的完全一样,所以这里是我用来准确描述你的描述的相关代码部分的复制/粘贴(使用可初始化迭代器的cifar10的训练/测试集):

  def build_datasets(self):
    """ Creates a train_iterator and test_iterator from the two datasets. """
    self.imgs_4d_uint8_placeholder = tf.placeholder(tf.uint8, [None, 32, 32, 3], 'load_images_placeholder')
    self.imgs_4d_float32_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3], 'load_images_float32_placeholder')
    self.labels_1d_uint8_placeholder = tf.placeholder(tf.uint8, [None], 'load_labels_placeholder')
    self.load_data_train = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_uint8_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })
    self.load_data_test = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_uint8_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })
    self.load_data_adversarial = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_float32_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })

    # Train dataset pipeline
    dataset_train = self.load_data_train
    dataset_train = dataset_train.shuffle(buffer_size=50000)
    dataset_train = dataset_train.repeat()
    dataset_train = dataset_train.map(self._img_augmentation, num_parallel_calls=8)
    dataset_train = dataset_train.map(self._img_preprocessing, num_parallel_calls=8)
    dataset_train = dataset_train.batch(self.hyperparams['batch_size'])
    dataset_train = dataset_train.prefetch(2)
    self.iterator_train = dataset_train.make_initializable_iterator()

    # Test dataset pipeline
    dataset_test = self.load_data_test
    dataset_test = dataset_test.map(self._img_preprocessing, num_parallel_calls=8)
    dataset_test = dataset_test.batch(self.hyperparams['batch_size'])
    self.iterator_test = dataset_test.make_initializable_iterator()



  def init(self, sess):
    self.cifar10 = Cifar10()    # a class I wrote for loading cifar10
    self.handle_train = sess.run(self.iterator_train.string_handle())
    self.handle_test = sess.run(self.iterator_test.string_handle())
    sess.run(self.iterator_train.initializer, feed_dict={self.handle: self.handle_train,
                                                         self.imgs_4d_uint8_placeholder: self.cifar10.train_data,
                                                         self.labels_1d_uint8_placeholder: self.cifar10.train_labels})