在Tensorflow中删除张量的尺寸

时间:2018-09-22 03:37:27

标签: python tensorflow slice tensor

我有一个张量为(50, 100, 1, 512)的张量,我想对其重塑形状或删除第三个维度,以便新张量为(50, 100, 512)

我用tf.slice尝试过tf.squeeze

a = tf.slice(a, [50, 100, 1, 512], [50, 100, 1, 512])
b = tf.squeeze(a)

当我尝试打印ab的形状时,一切似乎都正常,但是当我开始训练模型时,这个错误就出现了

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected size[0] in [0, 0], but got 50
     [[Node: Slice = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](MaxPool_2, Slice/begin, Slice/size)]]

我的slice有问题吗?我该如何解决。谢谢

3 个答案:

答案 0 :(得分:3)

通常tf.squeeze将删除尺寸。

a = tf.constant([[[1,2,3],[3,4,5]]])

以上张量形状为[1,2,3]。进行挤压操作后,

b = tf.squeeze(a)

现在,张量形状为[2,3]

答案 1 :(得分:1)

有多种方法可以做到这一点。 Tensorflow已开始支持索引编制。尝试

a = a[:,:,0,:]

OR

a = a[:,:,-1,:]

OR

a = tf.reshape(a,[50,100,512])

答案 2 :(得分:0)

在这种情况下,我使用tf.slice错误,应该是这样的:

a = tf.slice(a, [0, 0, 0, 0], [50, 100, 1, 512])
b = tf.squeeze(a)

您可以通过查看tf.slice documentation

来找出原因