python - tensorflow tfrecord input pipeline reads duplicate data

Posted: 2018-05-07 21:28:49

Tags: python tensorflow tfrecord

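I am following TensorFlow's official CIFAR tutorial to build an input pipeline for my image dataset. After training the model, I decided to test it on 50,000 images. However, I found that many images were tested multiple times while others were never tested at all. For example, the image named "1000_left" was tested several times, whereas images such as "100_left" were never tested. Can anyone help me figure out what is happening? Thanks!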

Here is how I load the TFRecord dataset and generate image batches for testing:

import tensorflow as tf  # TensorFlow 1.x queue-based input API

# NUM_EXAMPLES_PER_EPOCH_FOR_TEST is a module-level constant defined elsewhere.


def _generate_image_and_name_batch(image, name, min_queue_examples, batch_size):
    num_preprocess_threads = 16
    # Batch the image and name tensors using several enqueuing threads.
    images, names = tf.train.batch(
        [image, name],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)
    tf.summary.image('images', images)
    return images, tf.reshape(names, [batch_size])


def read_test(test_dir, batch_size):
    # Construct inputs for testing.
    if not tf.gfile.Exists(test_dir):
        raise ValueError("Failed to find file: " + test_dir)
    # Features stored in each test record.
    features = {"test/image": tf.FixedLenFeature([], tf.string),
                "test/name": tf.FixedLenFeature([], tf.string)}
    with tf.name_scope("test_input"):
        filename_queue = tf.train.string_input_producer(string_tensor=[test_dir])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(serialized_example, features=features)
        # Decode the raw bytes as float32 and reshape to 224x224x3.
        image = tf.decode_raw(features['test/image'], tf.float32)
        image = tf.reshape(image, [224, 224, 3])
        name = tf.cast(features['test/name'], tf.string)

        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TEST *
                                 min_fraction_of_examples_in_queue)

    return _generate_image_and_name_batch(image, name, min_queue_examples, batch_size)
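For reference: `tf.train.string_input_producer` is called here with its default `num_epochs=None`, which lets the filename queue cycle through `[test_dir]` an unlimited number of times, so the reader never signals the end of the file. Below is a minimal pure-Python sketch (not TensorFlow code) of that behavior under a simplifying assumption: records are read strictly in file order, ignoring any reordering by the 16 threads in `tf.train.batch`. The helper name `records_read` is hypothetical. It counts how often each record index is read when `ceil(N / B)` batches are pulled from such a looping source.

```python
import math
from collections import Counter
from itertools import cycle, islice


def records_read(num_examples, batch_size):
    """Hypothetical helper: count reads per record index when
    ceil(num_examples / batch_size) batches are pulled from a source that
    loops over the records forever (like num_epochs=None). Assumes strict
    file order; real multi-threaded batching may reorder records."""
    num_iter = math.ceil(num_examples / batch_size)
    source = cycle(range(num_examples))  # wraps around instead of stopping
    return Counter(islice(source, num_iter * batch_size))


print(records_read(10, 4))
```

With `records_read(10, 4)`, three batches of four pull 12 records, so records 0 and 1 are read twice; when `N` is a multiple of `B`, every record is read exactly once. The `tf.train.batch` documentation also notes that batching is nondeterministic when `num_threads > 1`.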

And here is how I run the test data through the model:

import math

import numpy as np

# sess, top_1_op, name and FLAGS come from the surrounding evaluation script.

# Start the queue runners.
coord = tf.train.Coordinator()
try:
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
    num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    print("Num iterations for total:", num_iter)
    step = 0

    image_names = []
    all_predictions = []

    while step < num_iter and not coord.should_stop():
        predictions = sess.run([top_1_op])[0]
        img_name = sess.run(name)
        all_predictions = np.concatenate([all_predictions, predictions])
        image_names = np.concatenate([image_names, img_name])
        step += 1

        if step % 100 == 0 or step + 1 == num_iter:
            print("Test step {} has finished".format(step))

except Exception as e:
    coord.request_stop(e)

coord.request_stop()
coord.join(threads, stop_grace_period_secs=10)
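In the loop above, `top_1_op` and `name` are fetched in two separate `sess.run` calls. With the TF1 queue API, every `sess.run` that evaluates a dequeue op removes a fresh batch from the queue, so the two calls see different batches. The following is a pure-Python simulation (not TensorFlow code) of that situation, assuming a looping source as with `num_epochs=None` and strictly ordered records; `simulate`, `next_batch` and `fetch_together` are hypothetical names. For every prediction it records the id of the image the prediction was made on and the id whose name was stored.

```python
import math
from itertools import cycle, islice


def simulate(num_examples, batch_size, fetch_together):
    """Simulate the test loop against a queue that cycles over the records
    forever. Each next_batch() call stands in for one sess.run() that
    evaluates the dequeue op. Returns (prediction_id, stored_name_id) pairs."""
    records = cycle(range(num_examples))  # record ids, epoch after epoch

    def next_batch():
        return list(islice(records, batch_size))

    num_iter = math.ceil(num_examples / batch_size)
    pairs = []
    for _ in range(num_iter):
        if fetch_together:
            # One call fetching both tensors: a single dequeue.
            batch = next_batch()
            pred_ids, name_ids = batch, batch
        else:
            pred_ids = next_batch()  # like sess.run([top_1_op])
            name_ids = next_batch()  # like sess.run(name): another batch
        pairs.extend(zip(pred_ids, name_ids))
    return pairs


print(simulate(12, 2, fetch_together=False))
```

With 12 records and a batch size of 2 (six batches per epoch), the separate-fetch version pulls two epochs' worth of batches in six steps: the stored names are 2, 3, 6, 7, 10 and 11, each twice, while 0, 1, 4, 5, 8 and 9 never appear, and no stored name matches its prediction. In this simulation, fetching both tensors in one call, as in `sess.run([top_1_op, name])`, keeps every name paired with its own prediction.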

I basically followed the same steps and code as in the CIFAR example. Please help! Thanks!

0 answers