tf.contrib.framework.get_global_step()的说明

时间:2017-06-05 22:26:52

标签: tensorflow

我正在尝试使用tensorflow编程,我遇到了这个函数:

global_step = tf.contrib.framework.get_global_step()

有人可以向我解释这里究竟发生了什么吗?我在tensorflow的文档中找到了这个解释,但对我来说并不是很清楚。

  

global_step:一个整数变量,表示为每个模型训练运行增加的步数计数器。可以通过get_global_step()函数在TensorFlow中轻松创建/递增。

其中get_global_step返回全局张量。

非常感谢!

1 个答案:

答案 0 :(得分:1)

返回global_step变量。

据我了解,此变量用于跟踪当前(全局)训练步骤,即当您将其传递给优化器时,每次对参数进行更新时,它们都会递增。

根据tf.train.Optimizer.minimize()的定义,您可以看到它是如何工作的:

  

global_step:可选变量在变量更新后递增1。

一个other use case是保存检查点:

saver.save(sess, FLAGS.train_dir, global_step=step)

PS:

  1. 您需要先定义它,然后才能调用此功能。即:global_step = tf.Variable(0, name="global_step", trainable=False)

  2. 如果您在同一会话中进行多项培训,则此变量将由所有优化器增加。