Can someone explain the train function in cifar10_train.py from the TensorFlow CIFAR-10 tutorial?

Date: 2017-10-17 11:00:54

Tags: machine-learning tensorflow

I am following the CIFAR-10 tutorial from https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10. The project contains six Python files. After searching the internet I understand cifar10.py and cifar10_input.py, but I cannot make sense of the train function in cifar10_train.py. Here is the train function from cifar10_train.py:

from datetime import datetime
import time

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS  # the individual flags are defined at module level


def train():
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force the input pipeline to CPU:0 to avoid operations sometimes
    # ending up on the GPU and causing a slowdown.

    with tf.device('/cpu:0'):
        images, labels = cifar10.distorted_inputs()

    # Build a graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate the loss.
    loss = cifar10.loss(logits, labels)

    # Build a graph that trains the model with one batch of examples
    # and updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
        """Logs loss and runtime."""

        def begin(self):
            # Runs once, before the session is created.
            self._step = -1
            self._start_time = time.time()

        def before_run(self, run_context):
            # Runs before each session.run() call; the returned
            # SessionRunArgs asks the session to also fetch `loss`.
            self._step += 1
            return tf.train.SessionRunArgs(loss)

        def after_run(self, run_context, run_values):
            # Runs after each session.run(); run_values.results holds the
            # `loss` value requested in before_run.
            if self._step % FLAGS.log_frequency == 0:
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time

                loss_value = run_values.results
                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                sec_per_batch = float(duration / FLAGS.log_frequency)

                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

    # MonitoredTrainingSession handles checkpointing to FLAGS.train_dir and
    # runs the hooks: StopAtStepHook stops training after FLAGS.max_steps,
    # NanTensorHook aborts if the loss becomes NaN, and _LoggerHook prints
    # the log line defined above.
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],
            config=tf.ConfigProto(
                log_device_placement=FLAGS.log_device_placement)) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

Can someone explain what is happening in the _LoggerHook class?

1 answer:

Answer 0 (score: 0)

It uses _LoggerHook, an implementation of tf.train.SessionRunHook, to log the loss while training.

The hook callbacks run in the following order:

call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:
    call hooks.before_run()
    try:
        results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
        break
    call hooks.after_run()
call hooks.end()
sess.close()

From here.
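Below is a minimal, self-contained sketch of the same hook pattern (TF 1.x API, as in the tutorial); the toy my_loss variable and the PrintLossHook name are illustrative, not part of the tutorial code:

import tensorflow as tf

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # A toy "loss": a variable that decays toward zero on every step.
    my_loss = tf.Variable(10.0)
    train_op = tf.group(my_loss.assign(my_loss * 0.9),
                        tf.assign_add(global_step, 1))

    class PrintLossHook(tf.train.SessionRunHook):
        def before_run(self, run_context):
            # Ask the session to fetch my_loss alongside the run() fetches.
            return tf.train.SessionRunArgs(my_loss)

        def after_run(self, run_context, run_values):
            # run_values.results is the fetched my_loss value.
            print('loss = %.3f' % run_values.results)

    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=5),
                   PrintLossHook()]) as sess:
        while not sess.should_stop():
            sess.run(train_op)

Every sess.run(train_op) triggers before_run and after_run, so this hook prints once per step; the tutorial's _LoggerHook additionally rate-limits its output to every FLAGS.log_frequency steps.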

In _LoggerHook, before_run asks the session to fetch the loss alongside each session.run, and after_run receives that value and prints it in a predefined format.
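As a concrete (hypothetical) reading of that log line: with the tutorial's default flags batch_size = 128 and log_frequency = 10, a measured duration of 1.0 second between logs yields examples_per_sec = 10 * 128 / 1.0 = 1280 and sec_per_batch = 1.0 / 10 = 0.1.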

Tutorial: https://www.tensorflow.org/tutorials/layers

Hope this helps.