为什么Tensorflow重塑tf.reshape()会破坏渐变的流动?

时间:2017-06-30 00:20:00

标签: python tensorflow

我正在创建一个tf.Variable(),然后使用该变量创建一个简单的函数,然后我使用tf.reshape()展平原始变量然后我在函数和展平之间取tf.gradients()变量。为什么返回[无]。

var = tf.Variable(np.ones((5,5)), dtype = tf.float32)
f = tf.reduce_sum(tf.reduce_sum(tf.square(var)))
var_f = tf.reshape(var, [-1])
print tf.gradients(f,var_f)

执行上面的代码块返回[None]。这是一个错误吗?请帮助!

1 个答案:

答案 0 :(得分:4)

您正在查找f相对于var_f的衍生产品,但f不是var_f的函数,而是var的函数。这就是为什么你得到[无]。现在,如果您将代码更改为:

 var = tf.Variable(np.ones((5,5)), dtype = tf.float32)
 var_f = tf.reshape(var, [-1])
 f = tf.reduce_sum(tf.reduce_sum(tf.square(var_f)))
 grad = tf.gradients(f,var_f)
 print(grad)

您的渐变将被定义:

  

tf.Tensor'gradients_28 / Square_32_grad / mul_1:0'shape =(25,)dtype = float32>

以下代码的图表可视化如下:

 var = tf.Variable(np.ones((5,5)), dtype = tf.float32, name='var')
 f = tf.reduce_sum(tf.reduce_sum(tf.square(var)), name='f')
 var_f = tf.reshape(var, [-1], name='var_f')
 grad_1 = tf.gradients(f,var_f, name='grad_1')
 grad_2 = tf.gradients(f,var, name='grad_2')

enter image description here

grad_1的导数未定义,而'grad_2`定义。显示了两个梯度的反向传播图(梯度图)。