numpy中跨行的多维索引(避免广播)

时间:2018-09-21 16:53:56

标签: numpy

我正在尝试索引另一个二维numpy ndarray的二维numpy ndarray。

我想要的效果是为索引数组中的每一行选择要索引的数组的相应行中的元素。也就是说,我希望第i行或我的索引数组为要索引的数组的第i行的元素编制索引(但没有其他行)。

但是,目前看来,当我尝试对数组进行索引时,正在对要索引的数组的每一行广播索引数组。

我正在使用的数组是(3,1001)数组和(3,5)数组。我正在尝试用(3,5)数组索引(3,1001)数组,并从要索引的数组的每个1001维行中选择5个元素。

例如,这是我想要的行为:

predictions_val[0][top_5[0]]
array([ 0.00222665,  0.00606673,  0.03681596,  0.85334235,  0.01018796], dtype=float32)

predictions_val[1][top_5[1]]
array([ 0.00106781,  0.00407206,  0.026693  ,  0.90732217,  0.0234713 ], dtype=float32)

predictions_val[2][top_5[2]]
array([ 0.00112946,  0.0016792 ,  0.06700196,  0.00367496,  0.87981129], dtype=float32)

这是我尝试同时建立索引时得到的行为:

predictions_val[:,top_5]
array([[[  2.22665281e-03,   6.06672745e-03,   3.68159562e-02,
           8.53342354e-01,   1.01879649e-02],
        [  5.12826555e-05,   8.53342354e-01,   1.41255208e-03,
           2.77817919e-04,   1.01879649e-02],
        [  2.17145571e-04,   2.77817919e-04,   8.53342354e-01,
           1.41255208e-03,   1.01879649e-02]],

       [[  5.50073055e-05,   8.74355683e-05,   2.71841218e-05,
           4.07205941e-03,   2.34712958e-02],
        [  1.06781046e-03,   4.07205941e-03,   2.66929977e-02,
           9.07322168e-01,   2.34712958e-02],
        [  5.84539608e-04,   9.07322168e-01,   4.07205941e-03,
           2.66929977e-02,   2.34712958e-02]],

       [[  1.05086729e-04,   2.83752568e-04,   7.68712547e-04,
           6.70019612e-02,   8.79811287e-01],
        [  4.69864433e-04,   6.70019612e-02,   3.67495860e-03,
           1.67920033e-03,   8.79811287e-01],
        [  1.12945912e-03,   1.67920033e-03,   6.70019612e-02,
           3.67495860e-03,   8.79811287e-01]]], dtype=float32)

我想要的每一行都存在于返回的数组中,但似乎top_5数组正在整个行中广播。

2 个答案:

答案 0 :(得分:1)

您必须正确索引数据。 np.indices可以为您提供帮助:

pred=rand(3,1001)
top=randint(0,1001,(3,5))

I,J=indices(top.shape)
res=pred[I,top]

然后res[i]的{​​{1}}就是您想要的。

答案 1 :(得分:0)

您应该能够执行以下类似操作;

X = ... # your data, shape (3,1001)
idx = ... # the wanted indices, shape (3,5)

# reshape idx to (2,idx.shape[0]*idx.shape[1])
idx = np.array([[i,index] for i in range(X.shape[0]) for index in idx[i]]).tranpose()

Wanted = X[idx[0],idx[1]]

这将创建一个数组idx,其中第一行是X中的所需行,第二行是X中的所需列。

相关问题