Extracting indices from a multidimensional array using a condition / max value

Asked: 2013-09-14 07:25:24

Tags: python arrays numpy indexing

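I have a 5-dimensional array.

I want to find, for each row of one (2D) slice, the index of the first value that satisfies a condition, and then use those indices to pull the values at the same positions out of another slice.

Here is my example: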

In [3]: g = np.arange(48400).reshape(20,11,11,2,10)

The two slices I am working with are:

In [4]: sliceA =  g[0,:,:,0,0]

In [5]: sliceA
Out[5]: 
array([[   0,   20,   40,   60,   80,  100,  120,  140,  160,  180,  200],
       [ 220,  240,  260,  280,  300,  320,  340,  360,  380,  400,  420],
       [ 440,  460,  480,  500,  520,  540,  560,  580,  600,  620,  640],
       [ 660,  680,  700,  720,  740,  760,  780,  800,  820,  840,  860],
       [ 880,  900,  920,  940,  960,  980, 1000, 1020, 1040, 1060, 1080],
       [1100, 1120, 1140, 1160, 1180, 1200, 1220, 1240, 1260, 1280, 1300],
       [1320, 1340, 1360, 1380, 1400, 1420, 1440, 1460, 1480, 1500, 1520],
       [1540, 1560, 1580, 1600, 1620, 1640, 1660, 1680, 1700, 1720, 1740],
       [1760, 1780, 1800, 1820, 1840, 1860, 1880, 1900, 1920, 1940, 1960],
       [1980, 2000, 2020, 2040, 2060, 2080, 2100, 2120, 2140, 2160, 2180],
       [2200, 2220, 2240, 2260, 2280, 2300, 2320, 2340, 2360, 2380, 2400]])

and one that I built separately and then assigned into the array (for illustration purposes):

In [6]: sliceB = np.array([[  3,  12,  21,  31,  41,  51,  69,  77,  83,  91, 100],
   ...:                  [  6,  12,  23,  33,  43,  51,  69,  77,  83,  91, 100],
   ...:                  [  8,  12,  27,  37,  47,  51,  69,  77,  83,  91, 100],
   ...:                  [  4,  12,  28,  38,  48,  51,  69,  77,  83,  91, 100],
   ...:                  [  7,  12,  29,  39,  49,  51,  69,  77,  83,  91, 100],
   ...:                  [  9,  12,  22,  32,  42,  51,  69,  77,  83,  91, 100],
   ...:                  [  6,  12,  21,  31,  41,  51,  69,  77,  83,  91, 100],
   ...:                  [  8,  12,  25,  35,  45,  51,  69,  77,  83,  91, 100],
   ...:                  [  5,  12,  26,  36,  46,  51,  69,  77,  83,  91, 100],
   ...:                  [  7,  12,  22,  32,  42,  51,  69,  77,  83,  91, 100],
   ...:                  [  3,  12,  24,  34,  44,  51,  69,  77,  83,  91, 100]])

In [11]: g[0,:,:,0,1] = sliceB 

In [12]: g[0,:,:,0,1]
Out[12]: 
array([[  3,  12,  21,  31,  41,  51,  69,  77,  83,  91, 100],
       [  6,  12,  23,  33,  43,  51,  69,  77,  83,  91, 100],
       [  8,  12,  27,  37,  47,  51,  69,  77,  83,  91, 100],
       [  4,  12,  28,  38,  48,  51,  69,  77,  83,  91, 100],
       [  7,  12,  29,  39,  49,  51,  69,  77,  83,  91, 100],
       [  9,  12,  22,  32,  42,  51,  69,  77,  83,  91, 100],
       [  6,  12,  21,  31,  41,  51,  69,  77,  83,  91, 100],
       [  8,  12,  25,  35,  45,  51,  69,  77,  83,  91, 100],
       [  5,  12,  26,  36,  46,  51,  69,  77,  83,  91, 100],
       [  7,  12,  22,  32,  42,  51,  69,  77,  83,  91, 100],
       [  3,  12,  24,  34,  44,  51,  69,  77,  83,  91, 100]])

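Now, for each row of sliceB, I want an array of the indices of the first element that satisfies a condition (for example, >= 35), i.e. the starred values: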

array([[  3,  12,  21,  31,  *41*,  51,  69,  77,  83,  91, 100],
       [  6,  12,  23,  33,  *43*,  51,  69,  77,  83,  91, 100],
       [  8,  12,  27,  *37*,  47,  51,  69,  77,  83,  91, 100],
       [  4,  12,  28,  *38*,  48,  51,  69,  77,  83,  91, 100],
       [  7,  12,  29,  *39*,  49,  51,  69,  77,  83,  91, 100],
       [  9,  12,  22,  32,  *42*,  51,  69,  77,  83,  91, 100],
       [  6,  12,  21,  31,  *41*,  51,  69,  77,  83,  91, 100],
       [  8,  12,  25,  *35*,  45,  51,  69,  77,  83,  91, 100],
       [  5,  12,  26,  *36*,  46,  51,  69,  77,  83,  91, 100],
       [  7,  12,  22,  32,  *42*,  51,  69,  77,  83,  91, 100],
       [  3,  12,  24,  34,  *44*,  51,  69,  77,  83,  91, 100]])

and then use it to build an array of the values at the corresponding indices in sliceA, i.e.:

array([[   0,   20,   40,   60,   *80*,  100,  120,  140,  160,  180,  200],
       [ 220,  240,  260,  280,  *300*,  320,  340,  360,  380,  400,  420],
       [ 440,  460,  480,  *500*,  520,  540,  560,  580,  600,  620,  640],
       [ 660,  680,  700,  *720*,  740,  760,  780,  800,  820,  840,  860],
       [ 880,  900,  920,  *940*,  960,  980, 1000, 1020, 1040, 1060, 1080],
       [1100, 1120, 1140, 1160, *1180*, 1200, 1220, 1240, 1260, 1280, 1300],
       [1320, 1340, 1360, 1380, *1400*, 1420, 1440, 1460, 1480, 1500, 1520],
       [1540, 1560, 1580, *1600*, 1620, 1640, 1660, 1680, 1700, 1720, 1740],
       [1760, 1780, 1800, *1820*, 1840, 1860, 1880, 1900, 1920, 1940, 1960],
       [1980, 2000, 2020, 2040, *2060*, 2080, 2100, 2120, 2140, 2160, 2180],
       [2200, 2220, 2240, 2260, *2280*, 2300, 2320, 2340, 2360, 2380, 2400]])

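I have spent several hours trying the following: np.amax, np.argmax, np.where, x[x > 34].min()

but I can't seem to find the missing link or the right combination.

For speed, I would like to do this without loops.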

2 answers:

Answer 0 (score: 3)

I can't test it right now, but it should be pretty simple:

idx = np.argmax(sliceB >= 35, axis=1) # index of first occurrence of condition
sliceA[np.arange(sliceA.shape[0]), idx]
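
For anyone who wants to run this, here is a minimal self-contained sketch using the data from the question. The helper name `first_match_values` and the `fill=-1` sentinel are my own assumptions, not part of the answer above. The sentinel is there because `argmax` on an all-False row returns 0, which would otherwise silently pick column 0 for rows where nothing satisfies the condition.

```python
import numpy as np

# Data from the question.
g = np.arange(48400).reshape(20, 11, 11, 2, 10)
sliceA = g[0, :, :, 0, 0]
sliceB = np.array([[3, 12, 21, 31, 41, 51, 69, 77, 83, 91, 100],
                   [6, 12, 23, 33, 43, 51, 69, 77, 83, 91, 100],
                   [8, 12, 27, 37, 47, 51, 69, 77, 83, 91, 100],
                   [4, 12, 28, 38, 48, 51, 69, 77, 83, 91, 100],
                   [7, 12, 29, 39, 49, 51, 69, 77, 83, 91, 100],
                   [9, 12, 22, 32, 42, 51, 69, 77, 83, 91, 100],
                   [6, 12, 21, 31, 41, 51, 69, 77, 83, 91, 100],
                   [8, 12, 25, 35, 45, 51, 69, 77, 83, 91, 100],
                   [5, 12, 26, 36, 46, 51, 69, 77, 83, 91, 100],
                   [7, 12, 22, 32, 42, 51, 69, 77, 83, 91, 100],
                   [3, 12, 24, 34, 44, 51, 69, 77, 83, 91, 100]])


def first_match_values(values, mask, fill=-1):
    """Per row, return values[row, j] for the first j where mask is True.

    Rows with no True entry get `fill` (a hypothetical sentinel).
    """
    idx = mask.argmax(axis=1)            # first True per row (0 if none)
    rows = np.arange(values.shape[0])    # one row index per row
    picked = values[rows, idx]           # pairwise (row, column) lookup
    return np.where(mask.any(axis=1), picked, fill)


result = first_match_values(sliceA, sliceB >= 35)
print(result)

# No element reaches 1000, so every row falls back to the sentinel.
no_match = first_match_values(sliceA, sliceB >= 1000)
print(no_match)
```

The `mask.any(axis=1)` check is the only addition to the two-line approach; if every row is guaranteed to contain a match, it can be dropped.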

Answer 1 (score: 2)

Something like this should work:

# Argsort each row of sliceA (the rows are already ascending, so this yields column indices)
tmp = np.argsort(sliceA, axis=1)
# Mask every position you don't want with a value larger than any index in the array
tmp[sliceB <= 34] = tmp.shape[-1] * 2
# Find the position of the minimum in each row
min_pos = tmp.argmin(axis=1)

# Finally, take the values from sliceA
print(sliceA[np.arange(sliceA.shape[0]), min_pos])
[  80  300  500  720  940 1180 1400 1600 1820 2060 2280]
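
Note that the `argsort` step gives the column positions only because each row of sliceA happens to be in ascending order. A variant that does not depend on that is to mask the column indices directly. The sketch below is an assumption-laden illustration rather than part of the answer: `first_true_index` is a hypothetical helper, and the small arrays `a` and `b` are toy data with deliberately unsorted rows.

```python
import numpy as np


def first_true_index(mask):
    """Return, per row, the column index of the first True in a 2D mask.

    Rows with no True entry yield mask.shape[1], which is out of bounds
    and must be handled by the caller before indexing.
    """
    n_cols = mask.shape[1]
    cols = np.arange(n_cols)
    # Keep the column index where the mask holds, otherwise use n_cols,
    # which is larger than any valid column index.
    masked = np.where(mask, cols, n_cols)
    return masked.min(axis=1)


# Toy data (not from the question); rows of `a` are not sorted.
a = np.array([[5, 1, 9, 3],
              [2, 8, 4, 6]])
b = np.array([[10, 40, 20, 50],
              [36, 1, 2, 3]])

idx = first_true_index(b >= 35)
vals = a[np.arange(a.shape[0]), idx]
print(idx, vals)
```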