关于tf.Tensor.set_shape()的澄清

时间:2016-02-17 08:57:43

标签: tensorflow

我的图像是478 x 717 x 3 = 1028178像素,等级为1.我通过调用tf.shape和tf.rank验证了它。

当我调用image.set_shape([478,717,3])时,它会抛出以下错误。

"Shapes %s and %s must have the same rank" % (self, other)) 
ValueError: Shapes (?,) and (478, 717, 3) must have the same rank

我通过首次测试再次测试到1028178,但错误仍然存​​在。

ValueError: Shapes (1028178,) and (478, 717, 3) must have the same rank

嗯,这确实有意义,因为一个是等级1而另一个是等级3.但是,为什么有必要抛出错误,因为像素的总数仍然匹配。

我当然可以使用tf.reshape并且它有效,但我认为这不是最佳的。

如TensorFlow常见问题解答

所述
  

x.set_shape()和x = tf.reshape(x)之间有什么区别?

     

tf.Tensor.set_shape()方法更新Tensor的静态形状   物体,通常用于提供额外的形状   这个信息无法直接推断出来。它没有改变   张量的动态形状。

     

tf.reshape()操作会创建一个具有不同动态形状的新张量。

创建一个新的张量涉及内存分配,当涉及更多的训练样例时,这可能会更昂贵。这是设计,还是我在这里遗漏了什么?

1 个答案:

答案 0 :(得分:54)

据我所知(并且我编写了该代码),Tensor.set_shape()中没有错误。我认为这种误解源于该方法令人困惑的名称。

要详细说明FAQ entry you quotedTensor.set_shape()是一个纯Python函数,改进给定tf.Tensor对象的形状信息。通过“改进”,我的意思是“更具体”。

因此,当您拥有形状为Tensor的{​​{1}}对象t时,这是一个未知长度的一维张量。致电(?,)时,您可以拨打t.set_shape((1028178,)),然后t将成型(1028178,)。这不会影响底层存储,或者后端上的任何事情:它仅仅意味着使用t.get_shape()的后续形状推断可以依赖于它是长度为1028178的向量的断言。

如果t的形状为t,则对(?,)的调用将失败,因为TensorFlow已经知道t.set_shape((478, 717, 3))是一个向量,因此不能有形t。如果您想使用(478, 717, 3)的内容制作具有该形状的新Tensor,则可以使用reshaped_t = tf.reshape(t, (478, 717, 3))。这在Python中创建了一个新的t对象;实际的implementation of tf.reshape()使用张量缓冲区的浅副本来做到这一点,所以在实践中它很便宜。

一个类比是tf.Tensor就像Java中面向对象语言中的运行时转换。例如,如果您有一个指向Tensor.set_shape()的指针,但事实上它知道它是Object,您可以执行投射String以传递(String) obj一个期望obj参数的方法。但是,如果您有String String并尝试将其强制转换为s,则编译器会给您一个错误,因为这两种类型不相关。