TensorFlow传递梯度未加工

时间:2016-10-28 12:49:55

标签: python numpy tensorflow

假设我在神经网络中使用了一些自定义操作binarizer。该操作需要Tensor并构造一个新的Tensor。我想修改该操作,使其仅用于前向传递。在向后传递中,当计算渐变时,它应该通过到达它的渐变。

更具体地说,binarizer是:

def binarizer(input):
    prob = tf.truediv(tf.add(1.0, input), 2.0)
    bernoulli = tf.contrib.distributions.Bernoulli(p=prob, dtype=tf.float32)
    return 2 * bernoulli.sample() - 1

我设置了我的网络:

# ...

h1_before_my_op = tf.nn.tanh(tf.matmul(x, W) + bias_h1)
h1 = binarizer(h1_before_b)

# ...

loss = tf.reduce_mean(tf.square(y - y_true))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

如何告诉TensorFlow在向后传递中跳过渐变计算?

我尝试按照this answer中的说明定义自定义操作,但是:py_func无法返回Tensor,这不是它的用途 - 我得到:< / p>

  

UnimplementedError(参见上面的回溯):不支持的对象类型Tensor

1 个答案:

答案 0 :(得分:1)

您正在寻找tf.stop_gradient(input, name=None)

  

停止渐变计算。

     

在图表中执行时,此操作按原样输出其输入张量。

h1 = binarizer(h1_before_b)
h1 = tf.stop_gradient(h1)