TensorFlow - 动态输入批量大小?

时间:2018-02-03 01:28:38

标签: python tensorflow

就我而言,我需要在训练期间动态更改batch_size。例如,我需要每10个纪元加倍batch_size。但问题是,虽然我知道如何使其动态化,但在 输入管道 中,我必须确定批量大小,如下面的代码所示。也就是说,使用tf.train.shuffle_batch我必须确定batch_size参数,之后我找不到任何改变它的方法。因此,我将不胜感激任何建议!如何制作动态输入批处理?

  filename_queue = tf.train.string_input_producer([self.tfrecords_file])
  reader = tf.TFRecordReader()

  _, serialized_example = self.reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
        'image/file_name': tf.FixedLenFeature([], tf.string),
        'image/encoded_image': tf.FixedLenFeature([], tf.string),
      })

  image_buffer = features['image/encoded_image']
  image = tf.image.decode_jpeg(image_buffer, channels=3)
  image = self._preprocess(image)
  images = tf.train.shuffle_batch(
        [image], batch_size=self.batch_size, num_threads=self.num_threads,
        capacity=self.min_queue_examples + 3*self.batch_size,
        min_after_dequeue=self.min_queue_examples
      )

1 个答案:

答案 0 :(得分:0)

我相信你想要做的是以下(我没试过这个,所以如果我犯了错误就纠正我。)

为批量大小创建占位符:

batch_size_placeholder = tf.placeholder(tf.int64)

使用占位符创建您的随机播放批处理:

images = tf.train.shuffle_batch(
    [image], batch_size=self.batch_size, num_threads=self.num_threads,
    capacity=self.min_queue_examples + 3*batch_size_placeholder,
    min_after_dequeue=self.min_queue_examples
  )

通过拨打sess.run

来传递批量大小
sess.run(my_optimizer, feed_dict={batch_size_placeholder: my_dynamic_batch_size})

我希望shuffle_batch能接受张量。

如果存在任何问题,您可以考虑使用数据集管道。这是进行数据流水线操作的更新,更有趣,更有光泽的方法。

https://www.tensorflow.org/programmers_guide/datasets