从具有索引列表的多维数组中选择

时间:2019-12-01 08:49:38

标签: python python-3.x numpy indexing numpy-slicing

假设我有一个大小为batch x max_len x output_size的数组,其中batchmax_lenoutput_size都对应到正的自然数。我有一个索引列表,这些索引对应于维度1中的各个项目(即max_len)。给定这些索引,如何从数组中选择?

作为一个具体的例子,假设我有以下内容:

>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]

在给定l的情况下选择idx时,我得到:

>>> l[:,idx,:].shape
(4, 4, 6)
>>>

我也尝试过np.take,但达到了相同的结果:

>>> np.take(l,idx,axis=1).shape
(4, 4, 6)
>>> 

但是,我要处理的输出是(4,1,6),因为我试图只查看batch中的每个元素(即第一维)。如何产生形状合适的输出?

1 个答案:

答案 0 :(得分:2)

idx扩展为与l相同的单位后使用np.take_along_axis-

np.take_along_axis(l,np.asarray(idx)[:,None,None],axis=1)

使用显式整数数组索引-

l[np.arange(len(idx)),idx][:,None] # skip [:,None] for (4,6) shaped o/p