如何使用多维索引数组从多维数组中取出指定的行?
当我尝试用索引数组(以及它的多种变体写法)来提取这些行时,都会抛出与形状相关的错误。
import numpy as np
# Six rows of 25 neighbourhood values each.  The reshape below regroups them
# as (input=2, channel=1, neighbourhood=3, value=25): rows 0-2 belong to
# input 0 and rows 3-5 belong to input 1.
neibs = np.array([[ 1.7117279052734375000000e+01, 1.7255817413330078125000e+01,
1.7325582504272460937500e+01, 1.7325582504272460937500e+01,
1.7255817413330078125000e+01, 2.2046510696411132812500e+01,
2.2232553482055664062500e+01, 2.2325582504272460937500e+01,
2.2325582504272460937500e+01, 2.2232553482055664062500e+01,
2.4651163101196289062500e+01, 2.4883720397949218750000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.4883720397949218750000e+01, 2.4651163101196289062500e+01,
2.4883720397949218750000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.4883720397949218750000e+01,
2.4651163101196289062500e+01, 2.4883720397949218750000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.4883720397949218750000e+01],
[ 1.8334333419799804687500e+01, 1.8333333969116210937500e+01,
1.8333333969116210937500e+01, 1.6000000000000000000000e+01,
1.2333333015441894531250e+01, 2.3333333969116210937500e+01,
2.3333333969116210937500e+01, 2.3333333969116210937500e+01,
2.0333333969116210937500e+01, 1.5666666984558105468750e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.1666666030883789062500e+01,
1.6666666030883789062500e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.1666666030883789062500e+01, 1.6666666030883789062500e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.1666666030883789062500e+01,
1.6666666030883789062500e+01],
[ 1.6334333419799804687500e+01, 2.1000000000000000000000e+01,
2.3333333969116210937500e+01, 2.3333333969116210937500e+01,
2.3333333969116210937500e+01, 1.7500000000000000000000e+01,
2.2500000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
1.7500000000000000000000e+01, 2.2500000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 1.7500000000000000000000e+01,
2.2500000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
1.7500000000000000000000e+01, 2.2500000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01],
[ 2.5001003265380859375000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01],
[ 2.5000999450683593750000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.2500000000000000000000e+01,
1.7500000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.2500000000000000000000e+01, 1.7500000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.2500000000000000000000e+01,
1.7500000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.2500000000000000000000e+01, 1.7500000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.2500000000000000000000e+01,
1.7500000000000000000000e+01],
[ 2.5000999450683593750000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01, 2.5000000000000000000000e+01,
2.5000000000000000000000e+01]])
neibs = np.reshape(neibs, (2, 1, 3, 25))
print('data shape')
print(np.shape(neibs))

# Number of highest-activity neighbourhoods to keep per (input, channel).
TOP_K = 2

# Activity of a neighbourhood = sum of its 25 values.  keepdims leaves a
# trailing length-1 axis so the result stays 4-D: (2, 1, 3, 1).
activity = np.sum(neibs, axis=3, keepdims=True)
print('activity')
print(np.shape(activity))
print(activity)

# Order the neighbourhoods of each input/channel by descending activity
# (argsort sorts ascending, hence the negation).  Shape: (2, 1, 3, 1).
indices = np.argsort(-activity, axis=2)
print('sorted indices')
print(indices)
print(np.shape(indices))

# Keep the TOP_K best neighbourhood indices per input/channel:
# shape (2, 1, TOP_K, 1).
top_indices = indices[:, :, :TOP_K]
print('top indices')
print(top_indices)
print(np.shape(top_indices))

# Gather the selected neighbourhoods along axis 2.  Plain fancy indexing
# (``neibs[top_indices]``) applies the whole index array to axis 0 only, so it
# picks entire inputs rather than neighbourhoods, and raises IndexError as soon
# as an index exceeds 1 (axis 0 has length 2).  take_along_axis instead pairs
# each index with its own (input, channel) position and broadcasts the
# trailing length-1 axis across all 25 values, giving (2, 1, TOP_K, 25).
top_neibs = np.take_along_axis(neibs, top_indices, axis=2)
print('top neibs')
print(top_neibs)
print(np.shape(top_neibs))