在TensorFlow急切模式下获得渐变的简便方法是
@tfe.implicit_value_and_gradients
def loss_fn(data):
output = network(data)
loss = f(output)
return loss
...
loss, grads = loss_fn(data)
如果我想同时计算额外数据,例如accuracy
和loss
,我该怎么办?即,我想要像
@tfe.implicit_values_and_gradients_of_first_result
def compute_fn(data):
output = network(data)
loss = f(output)
accuracy = g(output)
return loss, accuracy
...
loss, accuracy, loss_grads = compute_fn(data)
我可以通过将accuracy
值填充到单独的状态变量中来自行模拟。这是最好的方法,还是现有的便利功能使这很容易?
答案 0 :(得分:1)
对于这种情况,最好使用GradientTape
界面。如下所示:
def compute_fn(data):
with tfe.GradientTape() as tape:
output = network(data)
loss = f(output)
accuracy = g(output)
return loss, accuracy, tape.gradients(loss, network.variables)
希望有所帮助。