如何检查张量流图中是否执行了`compute_gradients`操作?

时间:2018-05-10 20:03:55

标签: python tensorflow

这是我的用例
我正在尝试实现Model Agnostic Meta Learning算法。在训练过程的某个阶段,我需要计算一些变量的梯度而不实际更新变量,在后面的步骤中我只想在计算梯度操作完成时做某些事情。

执行此操作的一种简单方法是使用tf.control_dependencies()

# In this step I would like to COMPUTE gradients
optimizer = tf.train.AdamOptimizer()
# let's assume that I already have loss and var_list
gradients = optimizer.compute_gradients(loss, var_list)

# In this step I would like to do some things ONLY if the gradients are computed
with tf.control_dependencies([gradients]):
    # do some stuff

问题
遗憾的是,上述代码段引发了错误,因为tf.control_dependencies期望gradientstf.Operationtf.Tensor,但compute_gradients返回list

错误讯息:
TypeError: Can not convert a list into a Tensor or Operation.

我想要什么?
我想要两件事之一:

  • 我可以通过tf.Operation中可以使用tf.Tensor函数获取optimizer.compute_gradientstf.control_dependencies的方法。
  • 或者我可以通过其他任何可靠的方式检查optimizer.compute_gradients是否实际计算过。

1 个答案:

答案 0 :(得分:1)

由于gradients是您希望确定计算的(渐变,变量)对的列表,因此您可以将其转换为张量/变量列表并将其用作control_inputs

with tf.control_dependencies([t for tup in gradients for t in tup]):