我试图在我的模型中实现一个从TFRecords二进制文件读取的输入管道; 每个二进制文件包含一个示例(图像,标签,我需要的其他东西)
我有一个带文件路径列表的文本文件;然后:
如果有人暗示我做错了什么,请告诉我
我的模型测试代码的简化版本如下; 谢谢!
def my_input(file_list, batch_size)
filename = []
f = open(file_list, 'r')
for line in f:
filename.append(params.TEST_RECORDS_DATA_DIR + line[:-1])
filename_queue = tf.train.string_input_producer(filename)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label_raw': tf.FixedLenFeature([], tf.string),
'name': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape(params.IMAGE_HEIGHT*params.IMAGE_WIDTH*3)
image = tf.reshape(image, (params.IMAGE_HEIGHT,params.IMAGE_WIDTH,3))
image = tf.cast(image, tf.float32)/255.0
image = preprocess(image)
label = tf.decode_raw(features['label_raw'], tf.uint8)
label.set_shape(params.NUM_CLASSES)
name = features['name']
images, labels, image_names = tf.train.batch([image, label, name],
batch_size=batch_size, num_threads=2,
capacity=1000 + 3 * batch_size, min_after_dequeue=1000)
return images, labels, image_names
def main()
with tf.Graph().as_default():
# call input operations
images, labels, image_names = my_input(file_list=params.TEST_FILE_LIST, batch_size=params.BATCH_SIZE)
# load a trained model and make predictions
prediction = infer(images, labels, image_names)
with tf.Session() as sess:
for step in range(params.N_STEPS):
prediction_values = sess.run([prediction])
# process output
return
答案 0 :(得分:0)
我的猜测是tf.train.string_input_producer(filename)
被设置为无限期地生成文件名,如果你在多个(2
)线程中批处理示例,可能是一个线程已经开始处理文件的情况第二次,而另一个还没有设法完成第一轮。要准确读取每个示例,请使用:
tf.train.string_input_producer(filename, num_epochs=1)
并在会话开始时初始化局部变量:
sess.run(tf.initialize_local_variables())