我正在关注tensorflow的官方cifar教程,为我的图像数据集创建一个输入管道。在训练模型后,我决定使用50000张图像进行测试。但是,经过测试,我发现许多图像都经过多次测试,而有些图像根本没有经过测试。例如,一次名为“1000_left”的图像被多次测试,而某些图像如“100_left”则根本没有被测试过。任何人都可以帮我确定发生了什么事吗?谢谢!
以下是我如何加载tfrecord数据集和生成的图像批次以进行测试:
def _generate_image_and_name_batch(image, name, min_queue_examples, batch_size):
num_preprocess_threads = 16
images, names = tf.train.batch([image, name], batch_size = batch_size, num_threads = num_preprocess_threads,
capacity = min_queue_examples + 3 * batch_size)
tf.summary.image('images', images)
return images, tf.reshape(names, [batch_size])
def read_test(test_dir, batch_size):
#constructs inputs for test
if not tf.gfile.Exists(test_dir):
raise ValueError("Failed to find file: " + test_dir)
#restore features of test record
features = {"test/image": tf.FixedLenFeature([], tf.string),
"test/name": tf.FixedLenFeature([], tf.string)}
with tf.name_scope("test_input"):
filename_queue = tf.train.string_input_producer(string_tensor = [test_dir])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features = features)
image = tf.decode_raw(features['test/image'], tf.float32)
image = tf.reshape(image, [224,224,3])
name = tf.cast(features['test/name'], tf.string)
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TEST * min_fraction_of_examples_in_queue)
return _generate_image_and_name_batch(image, name, min_queue_examples, batch_size)
以下是我加载测试数据的方法:
#start queue runners
coord = tf.train.Coordinator()
try:
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord = coord, daemon=True, start= True))
num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
print("Num iterations for total:", num_iter)
step = 0
image_names = []
all_predictions = []
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_1_op])[0]
img_name = sess.run(name)
all_predictions = np.concatenate([all_predictions, predictions])
image_names = np.concatenate([image_names, img_name])
step += 1
if step % 100 == 0 or step + 1 == num_iter:
print("Test step {} has finished".format(step))
except Exception as e:
coord.request_stop(e)
coord.request_stop()
coord.join(threads, stop_grace_period_secs= 10)
我基本上遵循与cifar示例中相同的步骤和代码。请帮忙!谢谢!