在我的tensorflor管道中,我创建了两个批处理队列:一个用于示例(图像),一个用于标签(整数),基本上与cifar10_input.py处理输入的方式相同:
<read files>
image = tf.image.decode_png(file_contents)
label = tf.string_to_number(label_str, out_type=tf.int32)
# Batch examples here into two queues.
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], ... )
model = build_model(image_batch)
loss = build_loss(model, label_batch)
然后模型将图像队列作为输入,损失评估模型输出和标签之间的差异。
我担心的是,如果我在一批图像上评估仅模型,标签队列将不再是&#34;对齐&#34;并且两个队列(图像和标签)将发散。
(model_output, ) = sess.run([model]) # Uses the image queue as input
(model_output, true_labels) = sess.run([model, label_batch]) # Are image/labels pairs valid?
如何确保两个队列保持同步,以便从每个队列中获取元素将始终返回正确的图像/标签对?
答案 0 :(得分:1)
雅罗斯拉夫的评论是正确的 - 你必须确保他们总是一起出局。
如果您想以与训练步骤交错的方式评估非标记数据的模型,那么您可以实例化模型的另一个副本(权重共享 - 设置重用=真)并从您的读取中读取备用的,非标记的数据源。如果你想要同时进行培训和推理,这不是一个不寻常的模式。
train_predict = model(train_input)
do something with the label here...
alternate:
use_predict = model(use_input, reuse=true)
(reuse = true引用变量作用域)。