TensorFlow:在不拖曳的情况下读取队列中的图像

时间:2016-04-22 02:08:25

标签: image queue tensorflow

我有614张图片的训练集,这些图像已经被洗牌。我希望按批次顺序读取图像。由于我的标签排列顺序相同,因此在读入批次时对图像进行任何改组都会导致标签错误。

这些是我阅读和添加图像到批处理的功能:

# To add files from queue to a batch:
def add_to_batch(image):

    print('Adding to batch')
    image_batch = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)

    # Add to summary
    tf.image_summary('images',image_batch,max_images=30)

    return image_batch

# To read files in queue and process:
def get_batch():

    # Create filename queue of images to read
    filenames = [('/media/jessica/Jessica/TensorFlow/StreetView/training/original/train_%d.png' % i) for i in range(1,614)]
    filename_queue =   tf.train.string_input_producer(filenames,shuffle=False,capacity=614)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    # Read and process image
    # Image is 500 x 275:
    my_image = tf.image.decode_png(value)
    my_image_float = tf.cast(my_image,tf.float32)
    my_image_float = tf.reshape(my_image_float,[275,500,4])

    return add_to_batch(my_image_float)

这是我执行预测的功能:

def inference(x):

    < Perform convolution, pooling etc.>

    return y_conv

这是我计算损失和执行优化的功能:

def train_step(y_label,y_conv):

    """ Calculate loss """
    # Cross-entropy
    loss = -tf.reduce_sum(y_label*tf.log(y_conv + 1e-9))

    # Add to summary
    tf.scalar_summary('loss',loss)

    """ Optimisation """
    opt = tf.train.AdamOptimizer().minimize(loss)

    return loss

这是我的主要功能:

def main ():

    # Training
    images = get_batch()
    y_conv = inference(images)
    loss = train_step(y_label,y_conv)

    # To write and merge summaries
    writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/StreetView/SummaryLogs/log_5', graph_def=sess.graph_def)
    merged = tf.merge_all_summaries()

    """ Run session """
    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)

    print "Running..."
    for step in range(5):

        # y_1 = <get the correct labels here>

        # Train
        loss_value = sess.run(train_step,feed_dict={y_label:y_1})
        print "Step %d, Loss %g"%(step,loss_value)

        # Save summary
        summary_str = sess.run(merged,feed_dict={y_label:y_1})
        writer.add_summary(summary_str,step)

    print('Finished')

if __name__ == '__main__':
  main()

当我检查image_summary时,图像似乎不是按顺序排列的。或者说,正在发生的事情是:

图像1-5:丢弃,图像6-10:读取,图像11-15:丢弃,图像16-20:读取等

所以看起来我的批次是两次,扔掉第一个并使用第二个?我尝试了一些补救措施,但似乎没有任何效果。我觉得我在理解调用images = get_batch()sess.run()

的根本错误

1 个答案:

答案 0 :(得分:4)

您的batch操作是FIFOQueue,因此每次使用它的输出时,都会提升状态。

您的第一个session.run来电使用train_step计算中的图片1-5,您的第二个session.run要求计算image_summary拉取图像5-6并在可视化中使用它们。

如果要在不影响输入状态的情况下可视化事物,则有助于在变量中缓存队列值,并将变量定义为输入而不是依赖于实时队列。

(image_batch_live,) = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)

image_batch = tf.Variable(
  tf.zeros((batch_size, image_size, image_size, color_channels)),
  trainable=False,
  name="input_values_cached")

advance_batch = tf.assign(image_batch, image_batch_live)

所以现在你的image_batch是一个静态值,你可以用它来计算损失和可视化。在步骤之间,您可以调用sess.run(advance_batch)来推进队列。

使用此方法产生轻微皱纹 - 默认保护程序会将image_batch变量保存到检查点。如果您更改批量大小,则检查点还原将失败,并且维度不匹配。要解决此问题,您需要指定要手动恢复的变量列表,并为其余部分运行初始化程序。