Is there a replacement for global_step in TensorFlow 2.0?

Date: 2019-03-26 10:48:48

Tags: tensorflow2.0

global_step appears to be missing in TensorFlow 2.0.

I have several callbacks that are interested in the current training progress, and I am not sure whether I need to implement my own step counter or rely on the epoch count...

Is there a recommended replacement?

2 answers:

Answer 0: (score: 2)

For now, it is best to declare your own global_step = tf.Variable(1, name="global_step") and use it manually.

I could not find a direct replacement for tf.train.get_or_create_global_step in the documentation; the only step-related part of the docs is the experimental section of the tf.summary module: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/summary/experimental
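A minimal sketch of the manual approach described in this answer, assuming TF 2.x eager execution; the training loop is a stand-in for your own:

```python
import tensorflow as tf

# Manually maintained replacement for tf.train.get_or_create_global_step;
# non-trainable so it is excluded from gradient updates.
global_step = tf.Variable(1, name="global_step", dtype=tf.int64, trainable=False)

for _ in range(3):  # stand-in for a real training loop
    # ... run one training step here ...
    global_step.assign_add(1)  # increment once per step

print(int(global_step.numpy()))  # starts at 1, three increments -> 4
```

If you checkpoint the variable alongside the model (e.g. via tf.train.Checkpoint), the counter survives restarts, which is the behaviour the old global_step provided.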

Answer 1: (score: 0)

Works in TensorFlow 2.3.1 and its Keras API.

Instances of tf.keras.optimizers.Optimizer inherit an iterations property. The implementation shows that this is a counter incremented after every training step. Access it from the optimizer before compiling the model.

import tensorflow as tf
from tensorflow.keras import Model

tf.compat.v1.disable_eager_execution()  # see note below

optimizer = tf.keras.optimizers.Adam()
training_step = optimizer.iterations  # incremented once per training step

# inputs/outputs are your model's input and output tensors
model = Model(inputs, outputs)
model.compile(
    loss=my_adaptive_loss_function(training_step),
    optimizer=optimizer)

Note: in my setup I had to disable eager execution to use this variable, otherwise I received the following TypeError. If your implementation is more flexible than mine, you may be able to avoid this.

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: pulse_features:0

During handling of the above exception, another exception occurred:
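One way to sidestep that error, sketched here as an assumption rather than the answerer's method: instead of capturing the counter inside the loss closure, read optimizer.iterations from a custom Keras callback, which runs eagerly. The model, data, and the StepLogger class below are illustrative:

```python
import numpy as np
import tensorflow as tf

class StepLogger(tf.keras.callbacks.Callback):
    """Illustrative callback: reads the optimizer's step counter after each batch."""
    def on_train_batch_end(self, batch, logs=None):
        self.last_step = int(self.model.optimizer.iterations.numpy())

# Tiny illustrative model and data
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam())

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

logger = StepLogger()
# 8 samples / batch_size 4 = 2 batches per epoch, times 2 epochs = 4 steps
model.fit(x, y, batch_size=4, epochs=2, verbose=0, callbacks=[logger])
print(logger.last_step)
```

This keeps eager execution enabled and addresses the original question's use case of callbacks that need the current training step.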