我希望将最接近的numpy数组中的n个元素的符号更改为某个值,但不能小于此值。也就是说,元素必须等于或大于该值。是否有任何快速的Numpy方法可以对大型数组进行有效处理?
我现在拥有的代码采用的n值更高或相等,但不是最接近的值,这是“还可以”,但对我的结果而言并不理想。
def update(arr, n, value):
updated = 0
i = 0
while updated < n:
if arr[i] >= value: # just a random value above "value"
arr[i] = -arr[i]
updated +=1
i += 1
arr = np.array([9, 8, 2, -4, 3, 4])
n = 3
value = 2
update(arr, n, value)
给我
arr = np.array([-9, -8, -2, -4, 3, 4])
当我反而想要
arr = np.array([9, 8, -2, -4, -3, -4])
答案 0 :(得分:0)
我没有就地更新数组,但是我会做类似的事情:
def update(arr, n, value):
arr_copy = arr.copy()
diffs = arr - value
absolute_diffs = np.abs(diffs)
update_indeces = np.argpartition(absolute_diffs, n)[:n]
arr_copy[update_indeces] *= -1
return arr_copy
答案 1 :(得分:0)
您可以使用argpartition
:
arr = np.random.random(20)
value = 0.5
n = 4
nl = np.count_nonzero(arr<value)
closest = np.argpartition(arr, (nl, nl+n-1))[nl:nl+n]
arr[closest] = -arr[closest]
arr
# array([ 0.33697627, 0.42607914, -0.63703314, -0.57517234, 0.82674228,
# -0.52929285, 0.64776714, 0.25609886, 0.24681445, 0.2486823 ,
# 0.76740245, 0.02368603, 0.21498096, -0.51033841, 0.19901665,
# 0.30939207, 0.69036139, 0.83178506, 0.97243443, 0.47620492])
答案 2 :(得分:0)
这应该有效:
def flip_some(a, n, value):
more_than = (a >= value)
first_n_elements = (a < np.sort(a[more_than])[n])
return np.where(more_than & first_n_elements, -a, a)
print(flip_some(np.array([9, 8, 2, -4, 3, 4]), 3, 2))
print(flip_some(np.arange(10), 2, 5))
输出:
[ 9 -8 -2 -4 -3 -4]
[ 0 1 2 3 4 -5 -6 -7 8 9]