如何在Tensorflow中提取占位符Tensor的形状值?

时间:2016-12-01 21:01:00

标签: python tensorflow

我为输入数据定义了x = tf.placeholder("float", shape=[None, 784])。稍后,我需要知道x形状的第一个值作为批量大小。我按x.get_shape().as_list()[0]提取值,但我得到了None。你能告诉我怎样才能正确提取它?非常感谢!

编辑:

我现在使用tf.get_shape()但它会导致另一个错误。在我的代码中,我定义了一个deconv功能:

def deconv(X, W, b, output_shape):
    X += b 
    return tf.nn.conv2d_transpose(X, W, output_shape, strides=[1, 1, 1, 1])

如果我以这样的方式将batch_size设置为intbatch_size = 50,则deconv功能的调用效果如下:

W_conv2_T = tf.ones([5, 5, 32, 64])
pool1_tr = deconv(conv2_tr, W_conv2_T, tf.zeros([64]), [batch_size, 14, 14, 32])

conv2_tr的形状为[50, 14, 14, 64]。由此产生的pool1_tr形状为[50, 14, 14, 32]。但是,如果我设置batch_size = tf.get_shape(x)[0]conv2_tr的形状为[None, 14, 14, 64],结果pool1_tr的形状将变为[None, None, None, None]。这个bug很奇怪。你能帮我解决这个问题吗?提前谢谢!

1 个答案:

答案 0 :(得分:6)

占位符中行数的值None表示它在运行时可能会有所不同,因此您必须使用tf.shape(x)将形状设置为tf.Tensor。以下代码应该有效:

batch_size = tf.shape(x)[0]