如何在numpy数组中选择行的索引?

时间:2019-02-15 21:24:35

标签: python numpy

我有以下numpy数组y_train

y_train =

2
2
1
0
1
1 
2
0
0

我需要随机选择n(n = 2)行的索引,如下所示:

n=2 
n indices of rows where y=0 
n indices of rows where y=1 
n indices of rows where y=2

我使用以下代码:

n=2
idx = [y_train[np.random.choice(np.where(y_train==np.unique(y_train)[I])[0],n)].index.tolist() \
 for i in np.unique(y_train).astype(int)]

我的真实数组y_train中的错误:

KeyError: '[70798 63260 64755 ...  7012 65605 45218] not in index'

2 个答案:

答案 0 :(得分:2)

如果您的预期输出是y_train中每个唯一值的随机选择索引列表:

idx = [np.random.choice(np.where(y_train == i)[0], size=2, \
       replace=False) for i in np.unique(y_train)]

输出:

[array([7, 8]), array([5, 4]), array([1, 0])]

如果要将数组展平为单个数组:

idx = np.array(idx).flatten()

输出:

array([7, 8, 5, 4, 1, 0])

答案 1 :(得分:1)

获取所需索引的另一种解决方案是使用nonzero并简单地循环访问range(n+1)

y_train = np.array([2,2,1,0,1,1,2,0,0])

indices = [np.random.choice((y_train==i).nonzero()[0], 2, replace=False) for i in range(n+1)]
print (indices)
# [array([7, 3]), array([5, 4]), array([0, 1])]

print (np.array(indices).ravel())
# [7 3 5 4 0 1]