急切模式下的@ tf.custom_gradient具有复杂的参数

时间:2018-08-29 14:15:24

标签: tensorflow

我有一个函数需要定义一个自定义渐变,但是我只需要针对作为嵌套张量列表的参数之一进行渐变。
我重新创建了一个描述我的问题的简单示例:

import tensorflow as tf
from tensorflow.contrib.eager.python import tfe
tf.enable_eager_execution()

@tf.custom_gradient
def function(const, var, iw):
    loss = const["next"]*(var[0] + tf.reduce_sum(var[1])*const["prev"] + iw*var[2])

    def grad(dy):
        return [0.0, dy * 2.0 * var, 0.0]
    return loss, grad


grads = tfe.implicit_value_and_gradients(function)
var = [tf.convert_to_tensor(13.0), [tf.convert_to_tensor(6.0), tf.convert_to_tensor(7.0)], tf.convert_to_tensor(3.0)]
const = {"next": 3.0, "prev": 2.0}
iw = 0.1
loss, grads_and_vars = grads(const, var, iw)
print(loss, grads_and_vars)

在这里,我仅需要相对于function的{​​{1}}的渐变,这是一个嵌套列表(表示由RNN计算的晶格结构上的成本)。

我有两个问题:
首先:tensorflow抱怨参数之一是var
 dict

第二:如果我以某种方式摆脱了此参数,则嵌套列表会出现更多问题: ValueError: Attempt to convert a value ({'next': 3.0, 'prev': 2.0}) with an unsupported type (<class 'dict'>) to a Tensor.

我想做的是针对tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [] != values[1].shape = [2] [Op:Pack] name: packed/ 相对于var的梯度,返回具有与function相同结构的嵌套列表。

0 个答案:

没有答案