在TensorFlow中将批量张量重塑为批量矢量

时间:2017-04-07 18:16:42

标签: python tensorflow reshape dimension

在创建计算图表期间,我有一个张量x,例如形状为[-1, a, b, c],我想将其重塑为[-1, a*b*c] 我试着这样做:

n = functools.reduce(operator.mul, x.shape[1:], 1)
tf.reshape(x, [-1, n])

但我收到了错误:

TypeError: unsupported operand type(s) for *: 'int' and 'Dimension'

我的问题是:TensorFlow有什么东西要做这个操作吗?

1 个答案:

答案 0 :(得分:3)

正如错误消息告诉您的那样,类型存在问题。如果您创建TensorFlow占位符,例如与

>>> import tensorflow as tf
>>> x = tf.placeholder(tf.float16, shape=(None, 3,7,4))

并在其上调用shape,然后返回值为

>>> x.shape
TensorShape([Dimension(None), Dimension(3), Dimension(7), Dimension(4)])

并且每个元素都是

>>> x.shape[1]
<class 'tensorflow.python.framework.tensor_shape.Dimension'>

即。一个Dimension类的TensorFlow。当然,operator.mul函数不知道如何处理这种类型。幸运的是,tf.TensorShape有一个as_list()函数,它将形状作为整数列表返回。

>>> x.shape.as_list()
[None, 3, 7, 4]

有了它,您可以像以前一样计算元素n的数量:

>>> import functools, operator
>>> n = functools.reduce(operator.mul, x.shape.as_list()[1:], 1)
>>> n 
84
>>> y = tf.reshape(x, [-1, n])