我试图理解tf代码,为此,我正在打印张量的形状。对于以下代码
print(x.shape)
print(tf.shape(x))
我得到输出
(?, 32, 32, 3)
Tensor("input/Shape:0", shape=(4,), dtype=int32)
这没有多大意义。根据我在网上发现的内容,可以使用tf.shape(x)动态获取批处理的大小。但这会产生非常错误的输出-4.我不确定(4,)
的来源以及如何为张量获取正确的值。
答案 0 :(得分:1)
实际上,两个结果是相同的。 4
的形状为(?,32,32,3)
。
x.shape()
返回一个元组,您可以不用sess.run()
来获得形状。您可以使用as_list()
将其转换为列表。
tf.shape(x)
返回一个张量,并且您需要运行sess.run()
以获取实际的数字。
一个例子:
import tensorflow as tf
import numpy as np
x = tf.placeholder(shape=(None,32,32,3),dtype=tf.float32)
print(x.shape)
print(tf.shape(x))
dim = tf.shape(x)
dim0 = tf.shape(x)[0]
with tf.Session()as sess:
dim,dim0 = sess.run([dim,dim0],feed_dict={x:np.random.uniform(size=(100,32,32,3))})
print(dim)
print(dim0)
#print
(?, 32, 32, 3)
Tensor("Shape:0", shape=(4,), dtype=int32)
[100 32 32 3]
100