获取NumPy数组中元素的索引

时间:2020-07-28 08:15:55

标签: python numpy pytorch

我有一个Numpy整数数组,其中包含很多重复的元素。

例如:

a = np.random.randint(0,5,20)
a
Out[23]:
array([3, 1, 2, 4, 1, 2, 4, 3, 2, 3, 1, 4, 4, 1, 2, 4, 2, 4, 1, 1])

有两种情况:

  1. 如果一个元素小于4,则获取该元素的所有索引
  2. 如果一个元素大于或等于4,则随机选择四个元素

我用循环解决了这个问题。

ans = np.array([])
num = 4
for i in range(1,5):
    indexes = np.where(a == i)[0] # all indexes of elements equal to i
    index_i = np.random.choice(indexes, num, False) if len(indexes) >=num else indexes
    ans = np.concatenate([ans, index_i])

np.sort(ans)

Out[57]:
array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 13., 14., 15.,
       17., 19.])

我可以在Numpy或PyTorch中无循环或更有效地解决此问题吗?

2 个答案:

答案 0 :(得分:1)

使用 Pandas ,您可以非常轻松地做到这一点。

首先将数组转换为 pandasonic 系列

s = pd.Series(a)

然后:

  • 按其值分组。
  • 对每个组应用一个函数,该函数:
    • 对于大小为 4 的组,仅返回该组
    • 对于具有更多成员的组,返回4个元素的随机样本 从他们那里。
  • 删除结果索引的第0 级(在分组过程中添加)。
  • 按(原始)索引排序,以恢复原始顺序(无 删除的元素,现在我们有了它们的原始值 相应的索引)。
  • Numpy 数组的形式返回上述结果的索引

执行此操作的代码是:

s.groupby(s).apply(lambda grp: grp if grp.size <= 4 else grp.sample(4))\
    .reset_index(level=0, drop=True).sort_index().index.values

对于包含以下内容的示例数组:

array([2, 2, 1, 0, 1, 0, 2, 2, 2, 3, 0, 2, 1, 0, 0, 3, 3, 0, 2, 4])

结果是:

array([ 0,  2,  4,  5,  7,  9, 10, 11, 12, 14, 15, 16, 17, 18, 19])

为证明此结果正确,我重复了源数组, 在返回索引的元素下方带有“ x”标记。

array([2, 2, 1, 0, 1, 0, 2, 2, 2, 3, 0, 2, 1, 0, 0, 3, 3, 0, 2, 4])
       x     x     x  x     x     x  x  x  x     x  x  x  x  x  x

答案 1 :(得分:0)

是的,您可以通过以下方式使用NumPy:

a = np.random.randint(0,10,20)
print(a)

num = 4
if str(np.where(a<num)[0].shape) != '(0,)':             # Condition 1
    ans = np.where(a<num)[0]
    print(ans)
if str(np.where(a>=num)[0].shape) != '(0,)':            # Condition 2
    ans = np.random.choice(a[np.where(a>=num)[0]], 4)
    print(ans)

'''Output:
[9 9 8 1 0 7 7 4 6 2 8 2 1 2 9 5 5 1 4 1]
[ 3  4  9 11 12 13 17 19]
[4 9 8 7]
'''

我仅针对您提到的情况做过。可能还有很多其他情况,例如两个条件都成立,或者第二种情况下少于4个数字。