使用TxK列索引数组从TxN numpy数组中选择TxK numpy数组

时间:2014-10-06 18:52:50

标签: python numpy indexing

这是间接索引问题。

可以通过列表理解来解决。

问题是,是否,或者如何在numpy内解决它,

当   data.shape(T,N) 和   c.shape(T,K)

c的每个元素都是0到N-1之间的int,也就是说, c的每个元素旨在引用data中的列号。

目标是获得out其中

out.shape = (T,K)

i

中的每个0..(T-1)

out[i] = [ data[i, c[i,0]] , ... , data[i, c[i,K-1]] ]

具体例子:

data = np.array([\
       [ 0,  1,  2],\
       [ 3,  4,  5],\
       [ 6,  7,  8],\
       [ 9, 10, 11],\
       [12, 13, 14]])

c = np.array([
      [0, 2],\
      [1, 2],\
      [0, 0],\       
      [1, 1],\       
      [2, 2]])

out should be out = [[0, 2], [4, 5], [6, 6], [10, 10], [14, 14]]

out的第一行是[0,2],因为所选列由c的第0行给出,它们是0和2,第0和第2列的数据[0]是0和2。

第二行输出是[4,5]因为所选列由c的第1行给出,它们是1和2,第1列和第2列的数据[1]是4和5。

Numpy花式索引似乎并没有以明显的方式解决这个问题,因为使用c(例如data[c]np.take(data,c,axis=1))索引数据总是产生3维数组。

列表理解可以解决它:

out = [ [data[rowidx,i1],data[rowidx,i2]] for (rowidx, (i1,i2)) in enumerate(c) ]

如果K为2,我认为这是可以接受的。如果K是可变的,那就不太好了。

必须为每个值K重写列表推导,因为它会从data的每一行中展开从c中挑选出的列。它也违反了DRY。

是否有完全基于numpy的解决方案?

2 个答案:

答案 0 :(得分:2)

您可以使用np.choose来避免循环:

In [1]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.

data = np.array([\
       [ 0,  1,  2],\
       [ 3,  4,  5],\
       [ 6,  7,  8],\
       [ 9, 10, 11],\
       [12, 13, 14]])

c = np.array([
      [0, 2],\
      [1, 2],\
      [0, 0],\
      [1, 1],\
      [2, 2]])
--

In [2]: np.choose(c, data.T[:,:,np.newaxis])
Out[2]: 
array([[ 0,  2],
       [ 4,  5],
       [ 6,  6],
       [10, 10],
       [14, 14]])

答案 1 :(得分:1)

这是通用解决方案的一种可能途径......

data创建掩码,为out的每列选择值。例如,第一个掩码可以通过写:

来实现
>>> np.arange(3) == np.vstack(c[:,0])
array([[ True, False, False],
       [False,  True, False],
       [ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)

>>> data[_]
array([ 2,  5,  6, 10, 14])

获取out第二列值的掩码:np.arange(3) == np.vstack(c[:,1])

所以,要获得out数组......

>>> mask0 = np.arange(3) == np.vstack(c[:,0])
>>> mask1 = np.arange(3) == np.vstack(c[:,1])
>>> np.vstack((data[mask0], data[mask1])).T
array([[ 0,  2],
       [ 4,  5],
       [ 6,  6],
       [10, 10],
       [14, 14]])

编辑:给定任意数组宽度KN,您可以使用循环来创建蒙版,因此out数组的一般构造可能看起来像这样:

np.vstack([data[np.arange(N) == np.vstack(c[:,i])] for i in range(K)]).T

编辑2 :一个稍微整洁的解决方案(虽然仍然依赖于循环)是:

np.vstack([data[i][c[i]] for i in range(T)])