我目前正在通过移植来自此tutorial的inception-v3重新训练脚本中的功能,在Tensorflow中尝试新的高级tf.contrib.learn API:
我是否知道如何在原始的retrain.py中看到,对于验证和训练输入,我是如何复制每次迭代的小批量采样?
目前我尝试在input_fn函数中使用tf.train.shuffle_batch但不是我不确定它是否有效。
为清晰起见,部分代码段
def train_input_fn():
# Get a batch of input bottleneck values, either calculated fresh every
# time with distortions applied, or from the cache stored on disk.
if do_distort_images:
(train_bottleneck_outputs, train_ground_truths) = get_random_distorted_bottlenecks(
sess, image_lists, -1, 'training',
FLAGS.image_dir, distorted_jpeg_data_tensor,
distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
else:
(train_bottleneck_outputs, train_ground_truths, _) = get_random_cached_bottlenecks(
sess, image_lists, -1, 'training',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor)
return tf.train.shuffle_batch([tf.constant(train_bottleneck_outputs),
tf.constant(train_ground_truths)],
batch_size=FLAGS.train_batch_size, capacity=1100,
min_after_dequeue=1000, enqueue_many=True, num_threads=2)