我想训练并检查此代码中验证集的准确性。我使用is_training
和tf.QueueBase.from_list
来切换列车和计算精度。数据集类型为tfrecord
。推理给出输入图像和浮点数keep_drop_prop
。
#inference, loss, training, evaluation functions ...
train_queue = tf.train.string_input_producer([train_data_path])
test_queue = tf.train.string_input_producer([validation_data_path])
# SELECT QUEUE
is_training = tf.placeholder(tf.bool, shape=None, name="is_training")
q_selector = tf.cond(is_training,
lambda: tf.constant(0),
lambda: tf.constant(1))
q = tf.QueueBase.from_list(q_selector, [train_queue, test_queue])
if is_training==True:
feature = {'train/image': tf.FixedLenFeature([], tf.string),
'train/label': tf.FixedLenFeature([], tf.int64)}
reader = tf.TFRecordReader()
_, serialized_example1 = reader.read(q)
features = tf.parse_single_example(serialized_example, features=feature)
images = tf.decode_raw(features1['train/image'], tf.float32)
labels = tf.cast(features['train/label'], tf.int32)
images = tf.reshape(images, [50, 50, 3])
batch_Xs,batch_Ys=tf.train.shuffle_batch([images,labels],batch_size=500,capacity=500,min_after_dequeue=100)
batch_Xs = tf.cast(batch_Xs,tf.float32)/255
else:
feature = {'validation/image': tf.FixedLenFeature([], tf.string),
'validation/label': tf.FixedLenFeature([], tf.int64)}
reader = tf.TFRecordReader()
_, serialized_example = reader.read(q)
features = tf.parse_single_example(serialized_example, features=feature)
images = tf.decode_raw(features['validation/image'], tf.float32)
labels = tf.cast(features['validation/label'], tf.int32)
images = tf.reshape(images, [50, 50, 3])
batch_Xs,batch_Ys=tf.train.shuffle_batch([images,labels],batch_size=500,capacity=500,min_after_dequeue=100)
batch_Xs = tf.cast(batch_Xs,tf.float32)/255
if is_training==True:
logits=inference(batch_Xs,0.7)
total_loss = loss(logits,batch_Ys)
train_op = training(total_loss,learning_rate=LEARNING_RATE)
accuracy = evaluation(logits,batch_Ys)
else:
logits=inference(batch_Xs,1)
accuracy = evaluation(logits,batch_Ys)
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
for i in range(NUM_ITER):
_,loss_value,acc=sess.run([train_op,total_loss,accuracy],feed_dict={is_training:True})
val_acc=sess.run([accuracy],feed_dict={is_training:False})
此代码的结果是:
UnboundLocalError: local variable 'train_op' referenced before assignment
InvalidArgumentError: You must feed a value for placeholder tensor 'is_training' with dtype bool
[[Node: is_training = Placeholder[dtype=DT_BOOL, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
Caused by op u'is_training', defined at:
File "/home/.../anaconda2/lib/python2.7/runpy.py", line 174, in _run_module_as_main
"__main__", fname, loader, pkg_name)
File "/home/.../anaconda2/lib/python2.7/runpy.py", line 72, in _run_code
exec code in run_globals
File "/home/.../.local/lib/python2.7/site-packages/ipykernel/__main__.py", line 3, in <module>
app.launch_new_instance()
File "/home/.../.local/lib/python2.7/site-packages/traitlets/config/application.py", line 658, in launch_instance
app.start()
File "/home/.../.local/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 474, in start
ioloop.IOLoop.instance().start()
File "/home/.../.local/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 177, in start
super(ZMQIOLoop, self).start()
File "/home/.../.local/lib/python2.7/site-packages/tornado/ioloop.py", line 887, in start
handler_func(fd_obj, events)
File "/home/.../.local/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
return fn(*args, **kwargs)
我认为队列无法访问输入数据......我完全不知道问题出在哪里...非常感谢您的帮助......