更改numpy数组中最接近值的n个元素的符号

时间:2019-04-03 10:37:32

标签: python numpy

我希望将最接近的numpy数组中的n个元素的符号更改为某个值,但不能小于此值。也就是说,元素必须等于或大于该值。是否有任何快速的Numpy方法可以对大型数组进行有效处理?

我现在拥有的代码采用的n值更高或相等,但不是最接近的值,这是“还可以”,但对我的结果而言并不理想。

def update(arr, n, value):
    updated = 0
    i = 0
    while updated < n:
        if arr[i] >= value: # just a random value above "value"
            arr[i] = -arr[i]
            updated +=1
        i += 1

arr = np.array([9, 8, 2, -4, 3, 4])
n = 3
value = 2
update(arr, n, value)

给我

arr = np.array([-9, -8, -2, -4, 3, 4])

当我反而想要

arr = np.array([9, 8, -2, -4, -3, -4])

3 个答案:

答案 0 :(得分:0)

我没有就地更新数组,但是我会做类似的事情:

def update(arr, n, value):
    arr_copy = arr.copy()
    diffs = arr - value
    absolute_diffs = np.abs(diffs)
    update_indeces = np.argpartition(absolute_diffs, n)[:n]
    arr_copy[update_indeces] *= -1
    return arr_copy

答案 1 :(得分:0)

您可以使用argpartition

arr = np.random.random(20)
value = 0.5
n = 4

nl = np.count_nonzero(arr<value)
closest = np.argpartition(arr, (nl, nl+n-1))[nl:nl+n]
arr[closest] = -arr[closest]
arr
# array([ 0.33697627,  0.42607914, -0.63703314, -0.57517234,  0.82674228,
#        -0.52929285,  0.64776714,  0.25609886,  0.24681445,  0.2486823 ,
#         0.76740245,  0.02368603,  0.21498096, -0.51033841,  0.19901665,
#         0.30939207,  0.69036139,  0.83178506,  0.97243443,  0.47620492])

答案 2 :(得分:0)

这应该有效:

def flip_some(a, n, value):
    more_than = (a >= value)
    first_n_elements = (a < np.sort(a[more_than])[n])
    return np.where(more_than & first_n_elements, -a, a)

print(flip_some(np.array([9, 8, 2, -4, 3, 4]), 3, 2))
print(flip_some(np.arange(10), 2, 5))

输出:

[ 9 -8  -2 -4 -3 -4]
[ 0  1  2  3  4 -5 -6 -7  8  9]