Separate validation and training graphs when feeding data to the network with queues in TensorFlow

Time: 2017-05-15 15:46:13

Tags: python tensorflow deep-learning

I have been looking into how to feed data to the network correctly using queues, but I could not find any solution on the internet.

At the moment my code can read the training data and train on it, but there is no validation or testing. Here are the important lines of code:

images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)

print("Initiliaze training")
logits = utils.inference(images)
loss_intermediate, loss = utils.get_loss(logits, volumes)

train_optimizer = utils.pre_training(loss, FLAGS.learning_rate)

summary_train = tf.summary.merge_all('train')
summary_test = tf.summary.merge_all('test')

init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

saver = tf.train.Saver(max_to_keep=2)
with tf.Session() as sess:

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run, sess.graph)
    summary_writer_test = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run_test, sess.graph)
    sess.run(init)

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("Start training")

    try:
        step = 0
        while not coord.should_stop():
            start_time = time.time()

            _, loss_intermediate_value, loss_value = sess.run([train_optimizer, loss_intermediate, loss])
            duration = time.time() - start_time
            if step % FLAGS.show_step == 0:
                print('Step %d: loss_intermediate = %.2f, loss = %.5f (%.3f sec)' % (step, loss_intermediate_value, loss_value, duration))
                summary_str = sess.run(summary_train)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            if step % FLAGS.test_interval == 0:
                ###### HERE VALIDATION HOW ? ############
                pass  # placeholder: the if-block needs a body to be valid Python
            step += 1
    except tf.errors.OutOfRangeError:
        # Raised when the input queue has produced num_epochs epochs of data;
        # this is the normal end-of-input signal, not an error in the code.
        print('ERROR IN CODE')
    finally:
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        # When done, ask the threads to stop.
        coord.request_stop()
        # Wait for threads to finish.
        coord.join(threads)

This function is used to read the data:

def inputs(train, batch_size, num_epochs):

  if not num_epochs: num_epochs = None
  filename = os.path.join(train)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)

    image, volume = read_and_decode(filename_queue)

    images, volumes = tf.train.shuffle_batch(
        [image, volume], batch_size=batch_size, num_threads=2,
        capacity=1000 * batch_size, min_after_dequeue=500)

    return images, volumes
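
read_and_decode is not shown in the question. For context, a typical TF 1.x implementation for a TFRecords file looks roughly like the sketch below; the feature keys, dtypes, and shapes are placeholders, not the actual ones used here.

def read_and_decode(filename_queue):
    # Read one serialized tf.train.Example from the TFRecords queue.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Feature names, dtypes, and shapes below are hypothetical placeholders.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'volume_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(features['image_raw'], tf.float32)
    volume = tf.decode_raw(features['volume_raw'], tf.float32)
    image = tf.reshape(image, [128, 128, 3])    # placeholder shape
    volume = tf.reshape(volume, [64, 64, 64])   # placeholder shape
    return image, volume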

I do not understand how to create another input queue or input graph for validation in TensorFlow. Can someone help me? Any help is appreciated!

EDIT

def _conv(self, inputs, nb_filter, kernel_size=1, strides=1, pad='VALID', name='conv'):
    with tf.name_scope(name) as scope:
        #kernel = tf.Variable(tf.truncated_normal([kernel_size, kernel_size, int(inputs.get_shape().as_list()[3]), int(nb_filter)], mean=0.0, stddev=0.0001), name='weights')
        kernel = tf.Variable(
            tf.contrib.layers.xavier_initializer(uniform=False)(
                [kernel_size, kernel_size, int(inputs.get_shape().as_list()[3]), int(nb_filter)]),
            name='weights')
        conv = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], padding=pad, data_format='NHWC')
        return conv
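
Note that _conv creates its kernel with tf.Variable, so every call to inference builds a fresh, independent set of weights; a validation branch built this way would not see the trained weights. The usual TF 1.x pattern for sharing weights is tf.get_variable inside a tf.variable_scope, which lets a second call reuse the first call's variables. A minimal sketch of such a variant (the reuse parameter is an addition, not part of the original code):

def _conv(self, inputs, nb_filter, kernel_size=1, strides=1, pad='VALID', name='conv', reuse=False):
    # tf.get_variable under a variable_scope lets a second graph branch
    # (e.g. validation) reuse the training weights; tf.Variable would
    # create a new copy on every call.
    with tf.variable_scope(name, reuse=reuse):
        in_channels = int(inputs.get_shape().as_list()[3])
        kernel = tf.get_variable(
            'weights',
            shape=[kernel_size, kernel_size, in_channels, int(nb_filter)],
            initializer=tf.contrib.layers.xavier_initializer(uniform=False))
        return tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], padding=pad, data_format='NHWC')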

EDIT 2

  with tf.Graph().as_default():
    print("Load Data...")
    images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)
    v_images, v_volumes = utils.inputs(FLAGS.val_file_path, FLAGS.batch_size, None)  # num_epochs=None so the validation queue cycles indefinitely

    print("input shape: " + str(images.get_shape()))
    print("output shape: " + str(volumes.get_shape()))

    print("Initialize training")
    logits = utils.inference(images, FLAGS.stacks, True)
    v_logits = utils.inference(v_images, FLAGS.stacks, False)

    tf.add_to_collection("logits", v_logits)

    loss = utils.get_loss(logits, volumes, FLAGS.stacks, 'train')
    v_loss = utils.get_loss(v_logits, v_volumes, FLAGS.stacks, 'val')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_optimizer = utils.pre_training(loss, FLAGS.learning_rate)

    validate = utils.validate(v_images, v_logits, v_volumes, FLAGS.scale)

    summary_train_op = tf.summary.merge_all('train')
    summary_val_op = tf.summary.merge_all('val')

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=2)
    with tf.Session() as sess:

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run, sess.graph)
        summary_writer_val = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run + FLAGS.run_val, sess.graph)
        sess.run(init)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            print("Start training")
            step = 0
            while not coord.should_stop():

                start_time = time.time()
                _, loss_list, image_batch, volume_batch, summary_str = sess.run([train_optimizer, loss, images, volumes, summary_train_op])
                duration = time.time() - start_time

                if (step + 1) % FLAGS.show_step == 0:
                    print('Step %d: (%.3f sec)' % (step, duration), end=': ')
                    print(", ".join('%.5f' % float(x) for x in loss_list))
                    summary_writer.add_summary(summary_str, step)

                if (step + 1) % FLAGS.val_interval == 0:

                    val_loss_sum_list = [0] * len(v_loss)

                    for val_step in range(0, FLAGS.val_iter):
                        _, val_loss_list, summary_str_val, image_input, volume_estimated, volume_ground_truth = sess.run([validate, v_loss, summary_val_op, v_images, v_logits, v_volumes])
                        val_loss_sum_list = [sum(x) for x in zip(val_loss_sum_list, val_loss_list)]

                        if (val_step + 1) == FLAGS.val_iter:
                            print('Validation Interval %d: ' % (step / FLAGS.val_interval), end='')
                            print(", ".join('%.5f' % float(x / FLAGS.val_iter) for x in val_loss_sum_list))
                            summary_writer_val.add_summary(summary_str_val, step)

                            #image_input, volume_estimated, volume_ground_truth = sess.run([v_images, v_logits, v_volumes])
                            #summary_val_images_op = utils.validate(image_input, volume_estimated, volume_ground_truth, FLAGS.scale, int(step / FLAGS.val_interval))

                if (step + 1) % FLAGS.step_save_checkpoint == 0:
                    checkpoint_file = os.path.join(FLAGS.train_dir + FLAGS.run, 'hourglass-model')
                    saver.save(sess, checkpoint_file, global_step=step)
                    print('Step: ' + str(step))
                    print('Saved: ' + checkpoint_file)

                step += 1
        except tf.errors.OutOfRangeError:
            # The training queue raises this after num_epochs epochs;
            # it marks the normal end of training.
            print('OUT OF RANGE ERROR')
        except Exception as e:
            print(sys.exc_info())
            print('Unexpected error in code')
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            print(exc_type, fname, exc_tb.tb_lineno)
        finally:
            print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
            checkpoint_file = os.path.join(FLAGS.train_dir + FLAGS.run, '-model')
            saver.save(sess, checkpoint_file, global_step=step)
            print('Step: ' + str(step))
            print('Saved: ' + checkpoint_file)

            # When done, ask the threads to stop.
            coord.request_stop()
            # Wait for threads to finish.
            coord.join(threads)

1 Answer:

Answer 0 (score: 0)

If you have already split your data into training and validation sets, you just have to create another input pipeline for the validation data. Using the code you provided, it should look something like this:

images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)
# create validation pipeline
v_images, v_volumes = utils.inputs(FLAGS.valid_file_path, FLAGS.batch_size, None)

logits = utils.inference(images)
loss_intermediate, loss = utils.get_loss(logits, volumes)
# define validation ops
v_logits = utils.inference(v_images)
accuracy = utils.accuracy(v_logits, v_volumes)

... a bunch of code here ...

with tf.Session() as sess:
    ... more code here ...
    if step % FLAGS.test_interval == 0:
        acc = sess.run(accuracy)  # run the op directly so acc is a scalar, not a one-element list
        print('Accuracy on validation data: {}'.format(acc))
    ... more code here ...
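
Two small caveats with the snippet above. First, a single sess.run(accuracy) only evaluates one validation batch; averaging over several batches gives a more stable estimate. A minimal sketch, assuming accuracy is a scalar tensor (run_validation and num_batches are hypothetical names):

def run_validation(sess, accuracy, num_batches=10):
    # Each sess.run(accuracy) dequeues one fresh batch from the
    # validation pipeline, so averaging several runs smooths the estimate.
    total = 0.0
    for _ in range(num_batches):
        total += sess.run(accuracy)
    return total / num_batches

Second, the two utils.inference calls only share weights if inference builds its variables with tf.get_variable under a reused tf.variable_scope; if it uses plain tf.Variable (as in the _conv method shown in the question's first edit), the validation branch evaluates a separate, untrained copy of the network.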

Is this what you were looking for?