我有以下numpy数组y_train
:
y_train =
2
2
1
0
1
1
2
0
0
我需要随机选择n
(n = 2)行的索引,如下所示:
n=2
n indices of rows where y=0
n indices of rows where y=1
n indices of rows where y=2
我使用以下代码:
n=2
idx = [y_train[np.random.choice(np.where(y_train==np.unique(y_train)[I])[0],n)].index.tolist() \
for i in np.unique(y_train).astype(int)]
我的真实数组y_train
中的错误:
KeyError: '[70798 63260 64755 ... 7012 65605 45218] not in index'
答案 0 :(得分:2)
如果您的预期输出是y_train
中每个唯一值的随机选择索引列表:
idx = [np.random.choice(np.where(y_train == i)[0], size=2, \
replace=False) for i in np.unique(y_train)]
输出:
[array([7, 8]), array([5, 4]), array([1, 0])]
如果要将数组展平为单个数组:
idx = np.array(idx).flatten()
输出:
array([7, 8, 5, 4, 1, 0])
答案 1 :(得分:1)
获取所需索引的另一种解决方案是使用nonzero
并简单地循环访问range(n+1)
y_train = np.array([2,2,1,0,1,1,2,0,0])
indices = [np.random.choice((y_train==i).nonzero()[0], 2, replace=False) for i in range(n+1)]
print (indices)
# [array([7, 3]), array([5, 4]), array([0, 1])]
print (np.array(indices).ravel())
# [7 3 5 4 0 1]