切片是为我的数组添加第三维 - 不知道为什么

时间:2016-08-28 19:07:27

标签: python numpy

我试图用仅等于2和3的标签来索引testdata。但是,当我运行此代码时,它将我的数组从2D(100 x 100)转换为3D(100 x 1 x 100)。 / p>

任何人都可以解释为什么会这样做吗?代码中的最后一行是罪魁祸首,但我不确定为什么会发生这种情况。

labels = testdata[:,0]
num2 = numpy.nonzero(labels == 2)
num2 = numpy.transpose(num2)
num3 = numpy.nonzero(labels == 3)
num3 = numpy.transpose(num3)
num = numpy.vstack([num2,num3])
testdata = testdata[num,:]

1 个答案:

答案 0 :(得分:1)

当有谜题时,打印中间值。更好的是,在交互式shell中运行测试用例,以便检查每个值,并了解正在发生的事情。跟踪形状。

看起来labels是一组数字,如:

In [212]: labels=np.array([0,1,2,2,3,2,0,3,2])

labels为2或3的索引:

In [213]: num2=np.nonzero(labels==2)
In [214]: num2
Out[214]: (array([2, 3, 5, 8], dtype=int32),)
In [215]: num3=np.nonzero(labels==3)

这是关键步骤 - transpose的目的是什么。请注意num2是一个带有一个1d数组的元组。

In [216]: num2=np.transpose(num2)
In [217]: num3=np.transpose(num3)
In [218]: num2
Out[218]: 
array([[2],
       [3],
       [5],
       [8]], dtype=int32)

转置后num2是一个列数组,(4,1)形状。

垂直连接它们会产生一个(6,1)数组:

In [220]: num=np.vstack([num2,num3])
In [221]: num
Out[221]: 
array([[2],
       [3],
       [5],
       [8],
       [4],
       [7]], dtype=int32)
In [222]: num.shape
Out[222]: (6, 1)
In [223]: labels[num]
Out[223]: 
array([[2],
       [2],
       [2],
       [2],
       [3],
       [3]])
In [224]: labels[num].shape
Out[224]: (6, 1)

使用该数组索引1d数组会生成另一个与索引形状相同的数组。索引x[num,:]执行相同的操作,但添加了最后一个维度。

如果我在第一维中用(2,5)数组索引(3,4)数组,结果是(2,5,4)数组:

In [227]: np.ones((3,4))[np.ones((2,5),int),:].shape
Out[227]: (2, 5, 4)