I am following the CIFAR-10 tutorial at https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10. In this project there are 6 classes. After searching the internet, I understood the cifar10.py and cifar10_input.py classes, but I cannot understand the train function in cifar10_train.py. Here is the train function from the cifar10_train.py class:
def train():
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    logits = cifar10.inference(images)
    loss = cifar10.loss(logits, labels)
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
Can someone explain what is happening in the _LoggerHook class?
Answer 0 (score: 0)
It uses a SessionRunHook implementation, _LoggerHook, to log the loss while training. A SessionRunHook is invoked in the following order:

  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()

(This pseudocode comes from the tf.train.SessionRunHook documentation.)
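To make that call order concrete, here is a minimal, self-contained sketch. It is not part of the tutorial; the counter graph and the _TraceHook name are made up for illustration. It just prints a line at every stage of the hook lifecycle:

  import tensorflow as tf

  class _TraceHook(tf.train.SessionRunHook):

    def begin(self):
      # Called once, before the session is created; the graph can still be modified here.
      print('begin')

    def after_create_session(self, session, coord):
      # Called once, right after the tf.Session has been created.
      print('after_create_session')

    def before_run(self, run_context):
      # Called before every mon_sess.run(); may request extra fetches.
      print('before_run')

    def after_run(self, run_context, run_values):
      # Called after every mon_sess.run(), with the results of any extra fetches.
      print('after_run')

    def end(self, session):
      # Called once when the session is about to close.
      print('end')

  # A trivial graph: just increment the global step on every run call.
  global_step = tf.contrib.framework.get_or_create_global_step()
  increment = tf.assign_add(global_step, 1)

  with tf.train.MonitoredTrainingSession(
      hooks=[tf.train.StopAtStepHook(last_step=3), _TraceHook()]) as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(increment)

The while not mon_sess.should_stop() loop in train() corresponds to the "while not stop is requested" line in the pseudocode above; StopAtStepHook is what eventually requests the stop.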
_LoggerHook collects the loss value just before each session.run (by returning it from before_run as a SessionRunArgs fetch) and then, in after_run, prints the loss in a pre-defined format.
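To see that fetch mechanism in isolation, here is a stripped-down variant of the logger. It is only a sketch: the loss tensor, train_op, and the LOG_FREQUENCY / BATCH_SIZE constants are stand-ins for the tutorial's real graph and FLAGS. before_run asks for loss, the monitored session merges that fetch into the next run call, and after_run receives the value in run_values.results:

  import time
  from datetime import datetime
  import tensorflow as tf

  LOG_FREQUENCY = 10   # stand-in for FLAGS.log_frequency
  BATCH_SIZE = 128     # stand-in for FLAGS.batch_size

  loss = tf.random_uniform([])        # made-up "loss" tensor for illustration
  train_op = tf.no_op(name='train')   # made-up "training" op

  class _MiniLoggerHook(tf.train.SessionRunHook):

    def begin(self):
      self._step = -1
      self._start_time = time.time()

    def before_run(self, run_context):
      self._step += 1
      # Ask the monitored session to also fetch `loss` in the next run call.
      return tf.train.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
      if self._step % LOG_FREQUENCY == 0:
        duration = time.time() - self._start_time
        self._start_time = time.time()
        loss_value = run_values.results   # value of `loss` from this run call
        examples_per_sec = LOG_FREQUENCY * BATCH_SIZE / duration
        print('%s: step %d, loss = %.2f (%.1f examples/sec)' %
              (datetime.now(), self._step, loss_value, examples_per_sec))

  with tf.train.MonitoredSession(hooks=[_MiniLoggerHook()]) as sess:
    for _ in range(30):
      # sess.run(train_op) here effectively becomes sess.run([train_op, loss]);
      # the loss result is routed to the hook's after_run().
      sess.run(train_op)

Because the hook adds the extra fetch itself, train() only ever calls mon_sess.run(train_op); the logging is completely decoupled from the training loop.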
Tutorial: https://www.tensorflow.org/tutorials/layers
Hope this helps.