如何理解TensorFlow中的静态形状和动态形状?

时间:2016-05-08 04:31:00

标签: tensorflow

TensorFlow FAQ中,它说:

  

在TensorFlow中,张量具有静态(推断)形状和a   动态(真实)形状。可以使用以下方法读取静态形状   tf.Tensor.get_shape()方法:这个形状是从中推断出来的   用于创建张量的操作,可能是部分操作   完成。如果静态形状没有完全定义,则动态形状   可以通过评估tf.shape(t)来确定张量t。

但我仍然无法完全理解静态形状和动态形状之间的关系。是否有任何显示其差异的例子?感谢。

3 个答案:

答案 0 :(得分:65)

有时张量的形状取决于在运行时计算的值。让我们看一下以下示例,其中x被定义为具有四个元素的tf.placeholder()向量:

x = tf.placeholder(tf.int32, shape=[4])
print x.get_shape()
# ==> '(4,)'

x.get_shape()的值是x的静态形状,而(4,)表示它是长度为4的向量。现在让我们应用{{ 3}}选择x

y, _ = tf.unique(x)
print y.get_shape()
# ==> '(?,)'

(?,)表示y是未知长度的向量。为什么不知道? tf.unique()会返回x中的唯一值,x的值未知,因为它是tf.placeholder(),所以在您输入之前它没有值它。让我们看看如果您提供两个不同的值会发生什么:

sess = tf.Session()
print sess.run(y, feed_dict={x: [0, 1, 2, 3]}).shape
# ==> '(4,)'
print sess.run(y, feed_dict={x: [0, 0, 0, 0]}).shape
# ==> '(1,)'

希望这表明张量可以具有不同的静态和动态形状。动态形状始终是完全定义的 - 它没有?尺寸 - 但静态形状可能不太具体。这使得TensorFlow能够支持tf.unique()tf.unique(x)等操作,这些操作可以具有可变大小的输出,并且可以在高级应用程序中使用。

最后,tf.dynamic_partition() op可用于获取张量的动态形状并在TensorFlow计算中使用它:

z = tf.shape(y)
print sess.run(z, feed_dict={x: [0, 1, 2, 3]})
# ==> [4]
print sess.run(z, feed_dict={x: [0, 0, 0, 0]})
# ==> [1]

答案 1 :(得分:2)

Tensorflow 2.0兼容答案:为了社区的利益,在 Tensorflow Version 2.x (> 2.0) 中提及mrry在其答案中指定的代码。

# Installing the Tensorflow Version 2.1
!pip install tensorflow==2.1

# If we don't Disable the Eager Execution, usage of Placeholder results in RunTimeError

tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.int32, shape=[4])
print(x.get_shape())

# ==> 4

y, _ = tf.unique(x)
print(y.get_shape())

# ==> (None,)

sess = tf.compat.v1.Session()
print(sess.run(y, feed_dict={x: [0, 1, 2, 3]}).shape)
# ==> '(4,)'
print(sess.run(y, feed_dict={x: [0, 0, 0, 0]}).shape)
# ==> '(1,)'

z = tf.shape(y)
print(sess.run(z, feed_dict={x: [0, 1, 2, 3]}))
# ==> [4]
print(sess.run(z, feed_dict={x: [0, 0, 0, 0]}))
# ==> [1]

答案 2 :(得分:0)

在上面的答案中定义很好,投票赞成。我还经历了一些观察,所以我想分享。

tf.Tensor.get_shape()可用于使用创建它的操作来推断输出,这意味着我们无需使用sess.run()(运行该操作)即可推断出输出,如名称静态形状所示。 例如,

  

c = tf.random_uniform([1,3,1,1])

是一个tf.Tensor,我们想在运行图形之前的任何步骤中了解它的形状,因此我们可以使用

  

c.get_shape()

tf.Tensor.get_shape无法动态(sess.run())的原因是由于输出类型 TensorShape而不是tf.tensor,输出TensorShape会限制sess.run()的使用。

  

sess.run(c.get_shape())

如果执行此操作,则会收到错误消息,表明TensorShape的类型无效,它必须是Tensor / operation或字符串。

另一方面,动态形状需要通过sess.run()运行操作才能获得形状

  

sess.run(tf.shape(c))

     

输出:array([1、3、1、1])

     

     

sess.run(c).shape

     

(1、3、1、1)#个元组

希望它有助于阐明张量流的概念。