tf.shape()在张量流中得到错误的形状

时间:2016-05-07 06:47:18

标签: python python-3.x tensorflow tensor

我定义了这样的张量:

x = tf.get_variable("x", [100])

但是当我尝试打印张量的形状时:

print( tf.shape(x) )

我得到 Tensor(“形状:0”,形状=(1,),dtype = int32),为什么输出结果不应该是shape =(100)

5 个答案:

答案 0 :(得分:106)

tf.shape(input, name=None)返回表示输入形状的1-D整数张量。

您正在寻找:x.get_shape(),它会返回TensorShape变量的x

更新:由于这个答案,我写了一篇文章来澄清Tensorflow中的动态/静态形状:https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

答案 1 :(得分:10)

澄清:

tf.shape(x)创建一个op并返回一个对象,该对象代表构造的op的输出,这是您当前正在打印的内容。要获得形状,请在会话中运行操作:

matA = tf.constant([[7, 8], [9, 10]])
shapeOp = tf.shape(matA) 
print(shapeOp) #Tensor("Shape:0", shape=(2,), dtype=int32)
with tf.Session() as sess:
   print(sess.run(shapeOp)) #[2 2]
信用:看了上面的答案后,我看到了tf.rank function in Tensorflow的答案,我发现这个答案更有帮助,我在这里试过改写它。

答案 2 :(得分:6)

只是一个简单的例子,说清楚:

a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
print('-'*60)
print("v1", tf.shape(a))
print('-'*60)
print("v2", a.get_shape())
print('-'*60)
with tf.Session() as sess:
    print("v3", sess.run(tf.shape(a)))
print('-'*60)
print("v4",a.shape)

输出将是:

------------------------------------------------------------
v1 Tensor("Shape:0", shape=(3,), dtype=int32)
------------------------------------------------------------
v2 (2, 3, 4)
------------------------------------------------------------
v3 [2 3 4]
------------------------------------------------------------
v4 (2, 3, 4)

这也应该有用: How to understand static shape and dynamic shape in TensorFlow?

答案 3 :(得分:2)

只需使用__add__获取静态形状

tensor.shape

要获得动态形状,请使用In [102]: a = tf.placeholder(tf.float32, [None, 128]) # returns [None, 128] In [103]: a.shape.as_list() Out[103]: [None, 128]

tf.shape()

您也可以使用dynamic_shape = tf.shape(a) 获取NumPy中的形状,如下例所示。

your_tensor.shape

此外,这个例子,对于可以In [11]: tensr = tf.constant([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6]]) In [12]: tensr.shape Out[12]: TensorShape([Dimension(2), Dimension(5)]) In [13]: list(tensr.shape) Out[13]: [Dimension(2), Dimension(5)] In [16]: print(tensr.shape) (2, 5) 的张量。

eval

答案 4 :(得分:0)

Tensorflow 2.0兼容答案 Tensorflow 2.x (>= 2.0) nessuno解决方案的兼容答案如下所示:

x = tf.compat.v1.get_variable("x", [100])

print(x.get_shape())