Why do automatic differentiation and gradient tape need to use a context manager?

Asked: 2019-03-23 04:42:17

Tags: python tensorflow automatic-differentiation

A context manager can turn two related operations into one. For example:

with open('some_file', 'w') as opened_file:
    opened_file.write('Hola!')

The code above is equivalent to:

file = open('some_file', 'w')
try:
    file.write('Hola!')
finally:
    file.close()
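As an aside, the same pairing of setup and teardown can also be written by hand with contextlib.contextmanager; this is just a sketch with a made-up opened helper to illustrate the idea:

from contextlib import contextmanager

@contextmanager
def opened(path, mode='w'):
    f = open(path, mode)   # setup: runs when the with block is entered
    try:
        yield f            # the body of the with block runs here
    finally:
        f.close()          # teardown: always runs, even if the body raises

with opened('some_file') as f:
    f.write('Hola!')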

But in the tutorial at https://www.tensorflow.org/tutorials/eager/custom_training_walkthrough#define_the_loss_and_gradient_function I found:

def grad(model, inputs, targets):
  with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets)
  return loss_value, tape.gradient(loss_value, model.trainable_variables)

What is it equivalent to?

1 answer:

Answer 0: (score: 1)

I am not a Python expert, but I think the with statement is defined by an __enter__ method and an __exit__ method (https://book.pythontips.com/en/latest/context_managers.html). For tf.GradientTape, the __enter__ method is:

  def __enter__(self):
    """Enters a context inside which operations are recorded on this tape."""
    self._push_tape()
    return self

https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/eager/backprop.py#L801-L804

and the __exit__ method is:

  def __exit__(self, typ, value, traceback):
    """Exits the recording context, no further operations are traced."""
    if self._recording:
      self._pop_tape()

https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/eager/backprop.py#L806-L809
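To see the protocol in isolation, here is a toy class with made-up names (not TensorFlow code) that mimics the same push/pop behaviour:

class Recorder:
    """Toy stand-in for GradientTape's push/pop behaviour."""
    def __enter__(self):
        print('push tape: start recording')   # analogous to self._push_tape()
        return self                           # bound to the name after 'as'
    def __exit__(self, typ, value, traceback):
        print('pop tape: stop recording')     # analogous to self._pop_tape()

with Recorder() as r:
    print('operations here are recorded')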

Then

with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets)

is roughly equivalent to:

tape = tf.GradientTape()
tape._push_tape()
loss_value = loss(model, inputs, targets)
tape._pop_tape()
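Here is a runnable sketch (my own minimal example, with a scalar variable in place of the tutorial's model) showing that the with form and the manual __enter__/__exit__ calls compute the same gradient. Note the try/finally, which mirrors the file example above; a real with statement additionally passes the actual exception info to __exit__ instead of None:

import tensorflow as tf

x = tf.Variable(3.0)

# Idiomatic form: __enter__ and __exit__ are called automatically.
with tf.GradientTape() as tape:
    y = x * x
print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32)

# Manual equivalent; finally ensures the tape is popped even on error.
tape = tf.GradientTape()
tape.__enter__()
try:
    y = x * x
finally:
    tape.__exit__(None, None, None)  # a real with passes exception info here
print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32)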