Tensorflow,更新变量以具有任意形状

时间:2017-05-25 23:03:08

标签: tensorflow

所以,根据documentation,我们可以使用tf.assign和validate_shape = False来改变形状。它确实改变了变量内容的形状,但是你可以从get_shape()获得的形状不会更新。例如:

>>> a = tf.Variable([1, 1, 1, 1])
>>> sess.run(tf.global_variables_initializer())
>>> tf.assign(a, [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]], validate_shape=False).eval()
array([[1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1]], dtype=int32)
>>> a.get_shape()
TensorShape([Dimension(4)])

令人讨厌的是,网络的后续层将其形状基于此变量的get_shape()值。因此,即使实际形状是正确的,Tensorflow也会抱怨尺寸不匹配。所以关于如何更新"相信"每个变量的形状?

2 个答案:

答案 0 :(得分:1)

简而言之:使用set_shape更新变量的静态形状。

通过阅读TF FAQ

,您可以了解正在发生的事情
  

在TensorFlow中,张量具有静态(推断)形状和a   动态(真实)形状。可以使用以下方法读取静态形状   tf.Tensor.get_shape方法:从操作中推断出这种形状   用于创建张量的,可能是部分完成的。如果   静态形状没有完全定义,Tensor t的动态形状   可以通过评估tf.shape(t)来确定。

所以静态形状没有被正确推断,你应该给TF一个提示。幸运的是,同一个FAQ中的下几行告诉你该怎么做:

  

tf.Tensor.set_shape方法更新Tensor的静态形状   物体,通常用于提供额外的形状   这个信息无法直接推断出来。它没有改变   张量的动态形状。

答案 1 :(得分:0)

由于validate_shape设置为false,因此变量的静态形状不会在图形中自动更新。 wrokaround是用新的形状(已知的)

设置它
a = tf.Variable([1, 1, 1, 1], validate_shape=False)
sess.run(tf.global_variables_initializer())
new_arr_assign = np.array([[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]])
tf.assign(a, new_arr_assign, validate_shape=False).eval(session=sess)
a.set_shape(new_arr_assign.shape)
a.get_shape()
# results: TensorShape([Dimension(2), Dimension(7)])