TensorFlow: separating the training and validation input queues

Time: 2016-08-24 06:40:39

Tags: machine-learning computer-vision tensorflow deep-learning

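I am trying to replicate the fully convolutional network (FCN) results using TensorFlow. I used Marvin Teichmann's implementation from GitHub, so I only needed to write the training wrapper. I created two graphs that share variables and two input queues, one for training and one for validation. To test my training wrapper, I used two short lists of training and validation files and ran validation immediately after every training epoch. I also printed the shape of every image coming out of the input queues to check that I was getting the correct input. However, once training starts, it seems that only images from the training queue are being dequeued. So both the training graph and the validation graph take their input from the training queue, and the validation queue is never accessed. Can anyone help explain and solve this problem?

Here is the relevant part of the code: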

import os
import random

import tensorflow as tf

def get_data(image_name_list, num_epochs, scope_name, num_class=NUM_CLASS):
    with tf.variable_scope(scope_name) as scope:
        images_path = [os.path.join(DATASET_DIR, i + '.jpg') for i in image_name_list]
        gts_path = [os.path.join(GT_DIR, i + '.png') for i in image_name_list]
        # Both filename queues use shuffle=False, so images and ground truths stay in the same order.
        seed = random.randint(0, 2147483647)
        image_name_queue = tf.train.string_input_producer(images_path, num_epochs=num_epochs, shuffle=False, seed=seed)
        gt_name_queue = tf.train.string_input_producer(gts_path, num_epochs=num_epochs, shuffle=False, seed=seed)
        reader = tf.WholeFileReader()
        # Read and decode one image, then add a batch dimension of size 1.
        image_key, image_value = reader.read(image_name_queue)
        my_image = tf.image.decode_jpeg(image_value)
        my_image = tf.cast(my_image, tf.float32)
        my_image = tf.expand_dims(my_image, 0)
        gt_key, gt_value = reader.read(gt_name_queue)
        # gt stands for ground truth
        my_gt = tf.cast(tf.image.decode_png(gt_value, channels=1), tf.float32)
        # Use the num_class argument rather than the global constant.
        my_gt = tf.one_hot(tf.cast(my_gt, tf.int32), num_class)
        return my_image, my_gt

train_image, train_gt = get_data(train_files, NUM_EPOCH, 'training')
val_image, val_gt = get_data(val_files, NUM_EPOCH, 'validation')
with tf.variable_scope('FCN16') as scope:
    train_vgg16_fcn = fcn16_vgg.FCN16VGG()
    train_vgg16_fcn.build(train_image, train=True, num_classes=NUM_CLASS, keep_prob=KEEP_PROB)
    scope.reuse_variables()
    val_vgg16_fcn = fcn16_vgg.FCN16VGG()
    val_vgg16_fcn.build(val_image, train=False, num_classes=NUM_CLASS, keep_prob=1)
"""
Define the loss, evaluation metric, summary, saver in the computation graph. Initialize variables and start a session.
"""
for epoch in range(starting_epoch, NUM_EPOCH):
    for i in range(train_num):
        _, loss_value, shape = sess.run([train_op, train_entropy_loss, tf.shape(train_image)])
        print(shape)
    for i in range(val_num):
        loss_value, shape = sess.run([val_entropy_loss, tf.shape(val_image)])
        print(shape)

1 Answer:

Answer 0 (score: 0)

To make sure you are reading different images, you can run:

[train_image_np, val_image_np] = sess.run([train_image, val_image])
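Once both arrays have been fetched, the comparison itself can be done with NumPy. The following is only a minimal sketch: `same_image` is a hypothetical helper, and the two arrays here are made-up stand-ins for what `sess.run` would return, not real images from the queues.

```python
import numpy as np


def same_image(a, b):
    """Return True when two fetched batches have identical shape and pixels."""
    return a.shape == b.shape and bool(np.array_equal(a, b))


# Hypothetical stand-ins for the arrays returned by
# sess.run([train_image, val_image]); real data would come from the queues.
train_image_np = np.zeros((1, 4, 6, 3), dtype=np.float32)
val_image_np = np.ones((1, 5, 5, 3), dtype=np.float32)

# Print both shapes side by side, then check whether the contents match.
print(train_image_np.shape, val_image_np.shape)
print(same_image(train_image_np, val_image_np))
```

If the check reports identical arrays for every step even though the two file lists differ, that would suggest both tensors really are fed from the same queue.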

To reuse the variables, this is better and safer:

with tf.variable_scope('FCN16') as scope:
    train_vgg16_fcn = fcn16_vgg.FCN16VGG()
    train_vgg16_fcn.build(train_image, train=True, num_classes=NUM_CLASS, keep_prob=KEEP_PROB)
with tf.variable_scope(scope, reuse=True):
    val_vgg16_fcn = fcn16_vgg.FCN16VGG()
    val_vgg16_fcn.build(val_image, train=False, num_classes=NUM_CLASS, keep_prob=1)
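To confirm that the two builds really share weights, the same scoping pattern can be tried on a toy model. This is a sketch under assumptions: `build` is a hypothetical stand-in for `FCN16VGG.build` that creates a single weight matrix, the placeholders replace the queue outputs, and the TF1-style API is accessed through `tf.compat.v1` so that it also runs on TensorFlow 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also shipped with TF 2.x


def build(inputs):
    """Hypothetical stand-in for FCN16VGG.build: one weight matrix."""
    w = tf1.get_variable("w", shape=[4, 2], initializer=tf1.zeros_initializer())
    return tf.matmul(inputs, w), w


with tf.Graph().as_default():
    # Placeholders stand in for the training and validation queue outputs.
    train_in = tf1.placeholder(tf.float32, [None, 4], name="train_in")
    val_in = tf1.placeholder(tf.float32, [None, 4], name="val_in")

    # First build creates the variable inside the 'FCN16' scope.
    with tf1.variable_scope("FCN16") as scope:
        train_out, w_train = build(train_in)
    # Re-entering the captured scope with reuse=True looks the variable up
    # instead of creating a second copy.
    with tf1.variable_scope(scope, reuse=True):
        val_out, w_val = build(val_in)

    shared = w_train is w_val
    var_names = [v.name for v in tf1.global_variables()]

print(shared, var_names)
```

Listing the global variables after both builds is a quick way to see whether a second set of weights was created by accident.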