获取可变批量维度的大小

时间:2016-07-07 01:06:05

标签: tensorflow

假设网络的输入是具有可变批量大小的placeholder,即:

x = tf.placeholder(..., shape=[None, ...])

喂食后可以获得x的形状吗? tf.shape(x)[0]仍然会返回None

2 个答案:

答案 0 :(得分:15)

如果x具有可变的批量大小,获得实际形状的唯一方法是使用tf.shape()运算符。此运算符返回tf.Tensor中的符号值,因此可以将其用作其他TensorFlow操作的输入,但要获取形状的具体Python值,需要将其传递给Session.run()

x = tf.placeholder(..., shape=[None, ...])
batch_size = tf.shape(x)[0]  # Returns a scalar `tf.Tensor`

print x.get_shape()[0]  # ==> "?"

# You can use `batch_size` as an argument to other operators.
some_other_tensor = ...
some_other_tensor_reshaped = tf.reshape(some_other_tensor, [batch_size, 32, 32])

# To get the value, however, you need to call `Session.run()`.
sess = tf.Session()
x_val = np.random.rand(37, 100, 100)
batch_size_val = sess.run(batch_size, {x: x_val})
print x_val  # ==> "37"

答案 1 :(得分:0)

您可以使用x.get_shape().as_list()获取张量x的形状。要获得第一个维度(批量大小),您可以使用x.get_shape().as_list()[0]