Training accuracy along with loss in TensorFlow

Date: 2019-11-29 04:35:28

Tags: python tensorflow deep-learning

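I am running a CNN model on the CIFAR-10 dataset. The code is taken from this TensorFlow tutorial. I need to print the training accuracy along with the loss every 10 steps (FLAGS.log_frequency), so I compute the accuracy in the loss() function.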

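from datetime import datetime
import time

import tensorflow as tf  # TF 1.x API (MonitoredTrainingSession, ConfigProto)

import cifar10  # the tutorial's model module

# Flags such as train_dir, max_steps and log_frequency are defined in the tutorial script.
FLAGS = tf.app.flags.FLAGS


def train():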
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss, accuracy = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss, accuracy and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs({"loss": loss, "acc": accuracy})  # Asks for loss and accuracy values.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results['loss']
          acc = run_values.results['acc']
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f, accuracy = %.3f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value, acc * 100,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
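Note that before_run fetches both tensors on every step, while after_run only prints every FLAGS.log_frequency steps, so each logged accuracy describes just the most recent batch of distorted training images. Even with a correct accuracy op, that single-batch value will be noisy. One option is to average it over the logging window. The following is a minimal sketch under that assumption: WindowedAccuracy is a hypothetical, TensorFlow-free helper (not part of the tutorial) that mirrors the hook's step counter.

```python
class WindowedAccuracy:
    """Hypothetical helper: averages per-step batch accuracy between log lines.

    Mirrors _LoggerHook's counter: the step starts at -1, is incremented once
    per training step, and a value is reported when step % log_frequency == 0.
    """

    def __init__(self, log_frequency):
        self.log_frequency = log_frequency
        self.step = -1
        self.total = 0.0
        self.count = 0

    def update(self, batch_acc):
        """Record one step's batch accuracy.

        Returns the mean accuracy since the previous report on logging steps,
        and None on all other steps.
        """
        self.step += 1
        self.total += batch_acc
        self.count += 1
        if self.step % self.log_frequency != 0:
            return None
        mean = self.total / self.count
        self.total, self.count = 0.0, 0
        return mean


# Usage with made-up accuracies: step 0 reports immediately,
# step 10 reports the mean of steps 1..10.
tracker = WindowedAccuracy(log_frequency=10)
first = tracker.update(0.1)
window = [tracker.update(a) for a in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]]
```

Inside the hook, the call would go in after_run with run_values.results['acc'], printing the returned mean whenever it is not None.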

Here is the loss() function. The first two lines compute the accuracy.

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float, and the batch accuracy tensor.
  """
  # Batch accuracy (added to the tutorial's loss()).
  correct_pred = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
  accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss'), accuracy
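A likely culprit is the first line of loss(). According to the docstring, labels is a 1-D tensor of shape [batch_size] holding integer class ids (which is also what tf.nn.sparse_softmax_cross_entropy_with_logits expects). For a 1-D tensor the last axis is the batch axis, so tf.argmax(labels, -1) returns a single index, the position of the largest label in the batch, rather than one class per example; tf.equal then broadcasts that scalar against every prediction. The following pure-Python sketch mimics this on a hypothetical 4-example, 3-class batch; the argmax helper is a stand-in for tf.argmax, not TensorFlow itself.

```python
# Hypothetical toy batch: 4 examples, 3 classes (stand-ins for CIFAR-10's
# [batch_size, 10] logits and [batch_size] integer labels).
logits = [[2.0, 0.1, 0.3],
          [0.2, 1.5, 0.1],
          [0.1, 0.2, 3.0],
          [1.0, 0.5, 0.2]]
labels = [0, 1, 2, 1]


def argmax(values):
    """Index of the first maximum; a stand-in for tf.argmax on a 1-D input."""
    return max(range(len(values)), key=lambda i: values[i])


# Like tf.argmax(logits, -1): one predicted class per example.
pred = [argmax(row) for row in logits]

# Like tf.argmax(labels, -1) on a 1-D tensor: a single position in the batch.
buggy_target = argmax(labels)

# Like tf.equal broadcasting that scalar against every prediction.
buggy_acc = sum(p == buggy_target for p in pred) / len(pred)

# Comparing predictions with the integer labels directly gives the real accuracy.
fixed_acc = sum(p == y for p, y in zip(pred, labels)) / len(pred)

print(pred, buggy_target, buggy_acc, fixed_acc)
```

In the graph itself, the corresponding change (a sketch, not tested here) would be to compare against the labels directly, e.g. correct_pred = tf.equal(tf.argmax(logits, 1), tf.cast(labels, tf.int64)). The result is still the accuracy on one distorted training batch, so it will fluctuate from step to step.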

But instead of increasing as training progresses, the accuracy prints what look like random numbers:

2019-11-28 20:31:24.911989: step 0, loss = 4.66, accuracy = 96.094 (10.5 examples/sec; 12.147 sec/batch)
2019-11-28 20:56:09.946520: step 10, loss = 4.61, accuracy = 0.000 (0.1 examples/sec; 868.503 sec/batch)
2019-11-28 21:20:33.645519: step 20, loss = 4.39, accuracy = 7.812 (0.1 examples/sec; 1226.370 sec/batch)
2019-11-28 22:56:20.576587: step 30, loss = 4.39, accuracy = 0.000 (0.1 examples/sec; 1234.693 sec/batch)
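Because the hook prints acc*100 with '%.3f', each accuracy in this log should correspond to a whole number of matching predictions in one batch. A quick check, assuming the tutorial's default batch_size of 128 (the flag value is not shown in the question): the sketch searches for every count k out of 128 that would print as each logged value.

```python
BATCH_SIZE = 128  # assumption: the tutorial's default cifar10 batch_size


def matching_counts(printed, batch_size=BATCH_SIZE):
    """Return every k in [0, batch_size] for which (k / batch_size) * 100,
    formatted with '%.3f' as in the logging hook, equals `printed`."""
    return [k for k in range(batch_size + 1)
            if '%.3f' % (k / batch_size * 100) == printed]


for printed in ['96.094', '0.000', '7.812']:
    print(printed, matching_counts(printed))
```

Each logged value maps to exactly one count (123, 0 and 10 out of 128). A genuine 96% on an untrained 10-class model at step 0 would be implausible, which points at the accuracy op rather than at training itself.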

Where am I going wrong?

0 answers