给出了像这样的整数数组
[1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5]
我需要掩盖重复超过N
次的元素。 要澄清:主要目标是检索布尔掩码数组,以后再用于装箱计算。
我想出了一个相当复杂的解决方案
import numpy as np
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
N = 3
splits = np.split(bins, np.where(np.diff(bins) != 0)[0]+1)
mask = []
for s in splits:
if s.shape[0] <= N:
mask.append(np.ones(s.shape[0]).astype(np.bool_))
else:
mask.append(np.append(np.ones(N), np.zeros(s.shape[0]-N)).astype(np.bool_))
mask = np.concatenate(mask)
例如给予
bins[mask]
Out[90]: array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
有更好的方法吗?
编辑,#2
非常感谢您的回答!这是MSeifert基准测试图的精简版。感谢您将我指向simple_benchmark
。仅显示4个最快的选项:
结论
由Florian H提出并由Paul Panzer修改的想法似乎是解决此问题的好方法,因为它很简单,而且仅numpy
。但是,如果您可以使用numba
,则MSeifert's solution的性能要好于其他。
我选择接受MSeifert的答案作为解决方案,因为它是更笼统的答案:它可以正确地处理具有(非唯一)连续重复元素块的任意数组。如果numba
不可行,Divakar's answer也值得一看!
答案 0 :(得分:7)
免责声明:这只是@FlorianH想法的合理实现:
def f(a,N):
mask = np.empty(a.size,bool)
mask[:N] = True
np.not_equal(a[N:],a[:-N],out=mask[N:])
return mask
对于更大的数组,这有很大的不同:
a = np.arange(1000).repeat(np.random.randint(0,10,1000))
N = 3
print(timeit(lambda:f(a,N),number=1000)*1000,"us")
# 5.443050000394578 us
# compare to
print(timeit(lambda:[True for _ in range(N)] + list(bins[:-N] != bins[N:]),number=1000)*1000,"us")
# 76.18969900067896 us
答案 1 :(得分:4)
方法1:这是一种矢量化方法-
from scipy.ndimage.morphology import binary_dilation
def keep_N_per_group(a, N):
k = np.ones(N,dtype=bool)
m = np.r_[True,a[:-1]!=a[1:]]
return a[binary_dilation(m,k,origin=-(N//2))]
样品运行-
In [42]: a
Out[42]: array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
In [43]: keep_N_per_group(a, N=3)
Out[43]: array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
方法2:更紧凑的版本-
def keep_N_per_group_v2(a, N):
k = np.ones(N,dtype=bool)
return a[binary_dilation(np.ediff1d(a,to_begin=a[0])!=0,k,origin=-(N//2))]
方法#3::使用分组计数和np.repeat
(虽然不会给我们掩码)-
def keep_N_per_group_v3(a, N):
m = np.r_[True,a[:-1]!=a[1:],True]
idx = np.flatnonzero(m)
c = np.diff(idx)
return np.repeat(a[idx[:-1]],np.minimum(c,N))
方法4::使用view-based
方法-
from skimage.util import view_as_windows
def keep_N_per_group_v4(a, N):
m = np.r_[True,a[:-1]!=a[1:]]
w = view_as_windows(m,N)
idx = np.flatnonzero(m)
v = idx<len(w)
w[idx[v]] = 1
if v.all()==0:
m[idx[v.argmin()]:] = 1
return a[m]
方法5::使用view-based
方法,而没有来自flatnonzero
的索引-
def keep_N_per_group_v5(a, N):
m = np.r_[True,a[:-1]!=a[1:]]
w = view_as_windows(m,N)
last_idx = len(a)-m[::-1].argmax()-1
w[m[:-N+1]] = 1
m[last_idx:last_idx+N] = 1
return a[m]
答案 2 :(得分:4)
我想提出一个使用numba的解决方案,该解决方案应该很容易理解。我假设您要“屏蔽”连续的重复项:
import numpy as np
import numba as nb
@nb.njit
def mask_more_n(arr, n):
mask = np.ones(arr.shape, np.bool_)
current = arr[0]
count = 0
for idx, item in enumerate(arr):
if item == current:
count += 1
else:
current = item
count = 1
mask[idx] = count <= n
return mask
例如:
>>> bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
>>> bins[mask_more_n(bins, 3)]
array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
>>> bins[mask_more_n(bins, 2)]
array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
使用simple_benchmark
-但是我还没有包括所有方法。这是对数-对数比例:
似乎numba解决方案无法击败Paul Panzer的解决方案,对于大型阵列而言,解决方案似乎要快一点(并且不需要其他依赖项)。
但是,两者似乎都胜过其他解决方案,但是它们确实返回掩码而不是“过滤”数组。
import numpy as np
import numba as nb
from simple_benchmark import BenchmarkBuilder, MultiArgument
b = BenchmarkBuilder()
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
@nb.njit
def mask_more_n(arr, n):
mask = np.ones(arr.shape, np.bool_)
current = arr[0]
count = 0
for idx, item in enumerate(arr):
if item == current:
count += 1
else:
current = item
count = 1
mask[idx] = count <= n
return mask
@b.add_function(warmups=True)
def MSeifert(arr, n):
return mask_more_n(arr, n)
from scipy.ndimage.morphology import binary_dilation
@b.add_function()
def Divakar_1(a, N):
k = np.ones(N,dtype=bool)
m = np.r_[True,a[:-1]!=a[1:]]
return a[binary_dilation(m,k,origin=-(N//2))]
@b.add_function()
def Divakar_2(a, N):
k = np.ones(N,dtype=bool)
return a[binary_dilation(np.ediff1d(a,to_begin=a[0])!=0,k,origin=-(N//2))]
@b.add_function()
def Divakar_3(a, N):
m = np.r_[True,a[:-1]!=a[1:],True]
idx = np.flatnonzero(m)
c = np.diff(idx)
return np.repeat(a[idx[:-1]],np.minimum(c,N))
from skimage.util import view_as_windows
@b.add_function()
def Divakar_4(a, N):
m = np.r_[True,a[:-1]!=a[1:]]
w = view_as_windows(m,N)
idx = np.flatnonzero(m)
v = idx<len(w)
w[idx[v]] = 1
if v.all()==0:
m[idx[v.argmin()]:] = 1
return a[m]
@b.add_function()
def Divakar_5(a, N):
m = np.r_[True,a[:-1]!=a[1:]]
w = view_as_windows(m,N)
last_idx = len(a)-m[::-1].argmax()-1
w[m[:-N+1]] = 1
m[last_idx:last_idx+N] = 1
return a[m]
@b.add_function()
def PaulPanzer(a,N):
mask = np.empty(a.size,bool)
mask[:N] = True
np.not_equal(a[N:],a[:-N],out=mask[N:])
return mask
import random
@b.add_arguments('array size')
def argument_provider():
for exp in range(2, 20):
size = 2**exp
yield size, MultiArgument([np.array([random.randint(0, 5) for _ in range(size)]), 3])
r = b.run()
import matplotlib.pyplot as plt
plt.figure(figsize=[10, 8])
r.plot()
答案 3 :(得分:2)
您可以通过建立索引来实现。对于任何N,代码将为:
N = 3
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5,6])
mask = [True for _ in range(N)] + list(bins[:-N] != bins[N:])
bins[mask]
输出:
array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6]
答案 4 :(得分:1)
您可以使用while循环来检查数组元素N向后定位是否等于当前位置。请注意,此解决方案假定数组是有序的。
> dplyr::vars(ends_with("color"))
<list_of<quosure>>
[[1]]
<quosure>
expr: ^ends_with("color")
env: global
答案 5 :(得分:1)
更好的方法是使用numpy
的{{1}}函数。您将在数组中获得唯一的条目,以及它们出现频率的计数:
unique()
输出:
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
N = 3
unique, index,count = np.unique(bins, return_index=True, return_counts=True)
mask = np.full(bins.shape, True, dtype=bool)
for i,c in zip(index,count):
if c>N:
mask[i+N:i+c] = False
bins[mask]
答案 6 :(得分:0)
您可以使用numpy.unique
。变量final_mask
可用于从数组bins
中提取traget元素。
import numpy as np
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
repeat_max = 3
unique, counts = np.unique(bins, return_counts=True)
mod_counts = np.array([x if x<=repeat_max else repeat_max for x in counts])
mask = np.arange(bins.size)
#final_values = np.hstack([bins[bins==value][:count] for value, count in zip(unique, mod_counts)])
final_mask = np.hstack([mask[bins==value][:count] for value, count in zip(unique, mod_counts)])
bins[final_mask]
输出:
array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
答案 7 :(得分:0)
您可以使用 grouby 对长度大于 N 的常见元素和过滤器列表进行分组。
import numpy as np
from itertools import groupby, chain
def ifElse(condition, exec1, exec2):
if condition : return exec1
else : return exec2
def solve(bins, N = None):
xss = groupby(bins)
xss = map(lambda xs : list(xs[1]), xss)
xss = map(lambda xs : ifElse(len(xs) > N, xs[:N], xs), xss)
xs = chain.from_iterable(xss)
return list(xs)
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
solve(bins, N = 3)