I am running a CNN model on the CIFAR-10 dataset. The code is taken from this TensorFlow tutorial. I need to print the training accuracy alongside the loss every 10 steps, so the accuracy is computed inside the loss() function.
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss, accuracy = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs({"loss": loss, "acc": accuracy})  # Asks for loss and accuracy values.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results['loss']
          acc = run_values.results['acc']
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f, accuracy = %.3f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value, acc * 100,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
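(For reference, this is my understanding of the hook mechanism, checked on a toy standalone graph, TF 1.x: whatever tensors before_run() requests via SessionRunArgs come back in run_values.results. The variable x below is an illustrative stand-in, not part of my model.)

# Minimal standalone sketch (TF 1.x): tensors requested in before_run()
# via SessionRunArgs are returned in run_values.results in after_run().
# `x` and `inc` are toy stand-ins for my model's accuracy and train_op.
import tensorflow as tf

x = tf.Variable(0.0)
inc = tf.assign_add(x, 1.0)  # stand-in for train_op

class _ProbeHook(tf.train.SessionRunHook):
  def before_run(self, run_context):
    # Ask the session to also evaluate x alongside the main fetch.
    return tf.train.SessionRunArgs({"x": x})

  def after_run(self, run_context, run_values):
    print("x =", run_values.results["x"])

with tf.train.MonitoredTrainingSession(hooks=[_ProbeHook()]) as sess:
  for _ in range(3):
    sess.run(inc)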
Here is the loss() function; the two lines right after the docstring compute the accuracy.
def loss(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float, plus the batch accuracy.
  """
  # Compute the batch accuracy from the predicted and true classes.
  correct_pred = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
  accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss'), accuracy
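(As a sanity check, I tried the same accuracy expression in isolation on made-up values, and with one-hot labels it behaves as I expect. The shapes below are my assumption for this toy sketch only, TF 1.x:)

# Standalone check (TF 1.x) of the accuracy expression on toy values.
# The one-hot label shape here is an assumption for this sketch only.
import tensorflow as tf

logits = tf.constant([[2.0, 0.1, 0.3],
                      [0.2, 3.0, 0.1]])  # [batch_size, num_classes]
labels = tf.constant([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])  # one-hot, [batch_size, num_classes]

correct_pred = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
  print(sess.run(accuracy))  # prints 1.0 for these toy values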
But in training, the accuracy does not increase over time; it prints what look like random numbers instead:
2019-11-28 20:31:24.911989: step 0, loss = 4.66, accuracy = 96.094 (10.5 examples/sec; 12.147 sec/batch)
2019-11-28 20:56:09.946520: step 10, loss = 4.61, accuracy = 0.000 (0.1 examples/sec; 868.503 sec/batch)
2019-11-28 21:20:33.645519: step 20, loss = 4.39, accuracy = 7.812 (0.1 examples/sec; 1226.370 sec/batch)
2019-11-28 22:56:20.576587: step 30, loss = 4.39, accuracy = 0.000 (0.1 examples/sec; 1234.693 sec/batch)
Where am I going wrong?