在输入fn中实现minibatch

时间:2017-06-06 14:52:26

标签: machine-learning tensorflow deep-learning

我目前正在通过移植来自此tutorial的inception-v3重新训练脚本中的功能,在Tensorflow中尝试新的高级tf.contrib.learn API:

我是否知道如何在原始的retrain.py中看到,对于验证和训练输入,我是如何复制每次迭代的小批量采样?

目前我尝试在input_fn函数中使用tf.train.shuffle_batch但不是我不确定它是否有效。

为清晰起见,部分代码段

def train_input_fn():
    # Get a batch of input bottleneck values, either calculated fresh every
    # time with distortions applied, or from the cache stored on disk.
    if do_distort_images:
        (train_bottleneck_outputs, train_ground_truths) = get_random_distorted_bottlenecks(
                sess, image_lists, -1, 'training',
                FLAGS.image_dir, distorted_jpeg_data_tensor,
                distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
    else:
        (train_bottleneck_outputs, train_ground_truths, _) = get_random_cached_bottlenecks(
                sess, image_lists, -1, 'training',
                FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
                bottleneck_tensor)
    return tf.train.shuffle_batch([tf.constant(train_bottleneck_outputs),
                                     tf.constant(train_ground_truths)],
                                      batch_size=FLAGS.train_batch_size, capacity=1100,
                                      min_after_dequeue=1000, enqueue_many=True, num_threads=2)

0 个答案:

没有答案