据我所知,NumPy 可以用 `numpy.where` 根据要查找的值生成数组的索引。
我的问题是:是否有函数可以针对由多个值组成的序列(模式)生成索引?
例如,使用此数组
a = np.array([1.,0.,0.,0.,1.,1.,0.,0.,0.,0.,...,1.,1.])
如果我能指定“4 个连续的零”这样的模式,并由函数告诉我它出现的索引,我就可以立即用另一个值替换它们。我已经写了一个能识别该模式的函数,但效率不高。任何提示都会非常有帮助。
答案 0 :(得分:1)
我好像每周都要给出一次这样的回答。最快且最节省内存的方法是在 `as_strided` 窗口视图之上使用 `void` 视图:
def rolling_window(a, window):
    """Return a strided view of every length-*window* slice along ``a``'s last axis.

    The result has shape ``a.shape[:-1] + (a.shape[-1] - window + 1, window)``
    and is a view onto ``a``'s memory (no copy). Neighbouring windows overlap,
    so writing through the view changes several windows at once.

    Based on @senderle's answer: https://stackoverflow.com/q/7100242/2901002
    """
    n_windows = a.shape[-1] - window + 1
    view_shape = a.shape[:-1] + (n_windows, window)
    # Consecutive windows start one element apart, and elements inside a window
    # are adjacent, so the new axis reuses the last axis' stride.
    view_strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, view_shape, view_strides)
def vview(a):
    """View each row of *a* (assumed 2-D) as a single opaque ``np.void`` item.

    Each row's bytes become one ``np.void`` element of
    ``a.dtype.itemsize * a.shape[1]`` bytes, so the result has shape
    ``(a.shape[0], 1)``. Rows can then be compared by their raw bytes.
    Non-contiguous input is copied first so that each row occupies one
    contiguous run of bytes.

    Based on @jaime's answer: https://stackoverflow.com/a/16973510/4427777
    """
    row_bytes = a.dtype.itemsize * a.shape[1]
    row_dtype = np.dtype((np.void, row_bytes))
    contiguous = np.ascontiguousarray(a)
    return contiguous.view(row_dtype)
def pattwhere_void(pattern, a):
    """Return the start indices of every occurrence of *pattern* in *a*.

    Every length-k window of ``a`` (built by ``rolling_window``) and the
    pattern are reinterpreted by ``vview`` as single ``np.void`` items. A
    window therefore matches when its raw bytes equal the pattern's bytes.
    Because the comparison is bytewise, ``-0.0`` does not match ``0.0``.

    Returns an empty integer array when the pattern is longer than ``a``.
    Follows the ``pattwhere_*`` template from @PaulPanzer's answer.
    """
    k, n = map(len, (pattern, a))
    a = np.asanyarray(a)
    pattern = np.asarray(pattern)
    # Bytewise matching requires identical dtypes. Without this, an integer
    # pattern such as [0, 0, 0, 0] never matches a float64 array.
    # Cast only when NumPy considers the conversion safe, so values are not
    # truncated into spurious matches.
    if np.can_cast(pattern.dtype, a.dtype):
        pattern = pattern.astype(a.dtype)
    pattern = np.atleast_2d(pattern)
    if k > n:
        return np.empty([0], int)
    # np.isin is the supported spelling of np.in1d (deprecated in NumPy 2.0);
    # flatnonzero flattens the (n-k+1, 1) mask.
    return np.flatnonzero(np.isin(vview(rolling_window(a, k)), vview(pattern)))
答案 1 :(得分:0)
这应该有效并且效率很高:
a = np.array([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0])
a0 = a == 0
# Position i starts four consecutive zeros when a0[i], a0[i+1], a0[i+2] and
# a0[i+3] are all True; the four shifted slices line those flags up.
np.where(np.logical_and.reduce([a0[:-3], a0[1:-2], a0[2:-1], a0[3:]]))
(array([6, 12], dtype=int64),) # indices of first of 4 consecutive 0's
答案 2 :(得分:0)
以下是三种不同的方法。
方法1:线性相关
import numpy as np
def pattwhere_corr(pattern, a):
    """Return the start indices of every occurrence of *pattern* in 1-D *a*.

    Uses linear correlation. For a window ``w`` and pattern ``p``,
    ``|w - p|^2 = w.w - 2 w.p + p.p``. So ``w == p`` exactly when both
    ``w.w == p.p`` and ``w.p == p.p``. Both rolling quantities are compared
    with ``p.p`` via ``np.isclose``, so float data matches approximately.

    Returns an empty integer array when the pattern is longer than ``a``.
    """
    pattern, a = map(np.asanyarray, (pattern, a))
    k = len(pattern)
    if k > len(a):
        return np.empty([0], int)
    # Squares and products in small integer dtypes (e.g. uint8) overflow and
    # wrap around, which yields false matches or misses. Compute in at least
    # 64-bit precision instead.
    work = np.result_type(pattern, a, np.int64)
    pattern = pattern.astype(work, copy=False)
    a = a.astype(work, copy=False)
    n = np.dot(pattern, pattern)
    a2 = a * a
    # Rolling sum of squares over each length-k window: first window's sum,
    # then add the entering square and drop the leaving one.
    slf = a2[:k].sum() + np.r_[0, np.cumsum(a2[k:] - a2[:-k])]
    crs = np.correlate(a, pattern, 'valid')
    return np.flatnonzero(np.isclose(slf, n) & np.isclose(crs, n))
方法2:逐个元素依次筛选候选位置
def pattwhere_sequ(pattern, a):
    """Return the start indices of every occurrence of *pattern* in 1-D *a*.

    The search starts from every position that could hold the pattern's last
    element. It then walks backwards through the pattern, keeping only the
    candidates whose preceding element also matches.

    Returns an empty integer array when the pattern is longer than ``a``.
    """
    pattern, a = map(np.asanyarray, (pattern, a))
    k = len(pattern)
    if k > len(a):
        return np.empty([0], int)
    # Only positions >= k-1 can end a complete match. Seeding from lower
    # positions would let ``hits`` go negative in the loop below. NumPy
    # negative indices wrap to the end of ``a``, which reports bogus matches
    # such as -1.
    hits = np.flatnonzero(a[k - 1:] == pattern[-1]) + (k - 1)
    for p in pattern[-2::-1]:
        hits -= 1
        hits = hits[a[hits] == p]
    return hits
方法3:蛮力
def pattwhere_direct(pattern, a):
    """Return the start indices where *pattern* occurs in 1-D *a* (brute force).

    Builds an ``(n-k+1, k)`` strided view holding every window of ``a`` and
    compares each window with the pattern elementwise.

    Returns an empty integer array when the pattern is longer than ``a``.
    """
    pattern = np.asanyarray(pattern)
    a = np.asanyarray(a)
    k = len(pattern)
    n = len(a)
    if k > n:
        return np.empty([0], int)
    # Rows start one element apart and elements within a row are adjacent,
    # so both axes reuse a's own stride (the windows are a view, not a copy).
    windows = np.lib.stride_tricks.as_strided(a, shape=(n - k + 1, k), strides=a.strides * 2)
    row_matches = np.all(windows == pattern, axis=1)
    return np.nonzero(row_matches)[0]
一些测试:
# Sanity check: the three implementations must return the same start indices.
k, n, p = 4, 100, 5
pattern, a = np.random.randint(0, p, (k,)), np.random.randint(0, p, (n,))
# np.array_equal also compares lengths. A bare ``==`` between result arrays
# of different lengths cannot broadcast: recent NumPy raises, and older
# versions return a scalar False with a warning.
reference = pattwhere_sequ(pattern, a)
print('results consistent:',
      np.array_equal(reference, pattwhere_corr(pattern, a))
      and np.array_equal(reference, pattwhere_direct(pattern, a)))
from timeit import timeit
# Time every module-level function whose name starts with 'pattwhere' on
# three (pattern length k, data length n, value range p) configurations.
for k, n, p in [(4, 100, 5), (10, 1000000, 4), (1000, 10000, 3)]:
    print('k, n, p = ', k, n, p)
    pattern = np.random.randint(0, p, (k,))
    a = np.random.randint(0, p, (n,))
    glb = dict(pattern=pattern, a=a)
    kwds = dict(number=1000, globals=glb)
    # At module level locals() is the module namespace, so this enumerates the
    # functions defined above. The list() snapshot is required because the
    # loop body binds new module-level names while iterating. Each object is
    # stored as glb['func'] so the timed statement can call it.
    for name, glb['func'] in list(locals().items()):
        if name.startswith('pattwhere'):
            # timeit returns the total seconds for 1000 calls, which is
            # numerically the number of milliseconds per call.
            elapsed = timeit('func(pattern, a)', **kwds)
            print(name.replace('pattwhere_', '').ljust(8), '{:8.6f} ms'.format(elapsed))
示例输出。请注意,这些基准测试中的模式只以随机数据中自然出现的频率出现;如果模式出现得非常频繁,结果可能会有所不同。
results consistent: True
k, n, p = 4 100 5
corr 0.090752 ms
sequ 0.015759 ms
direct 0.023338 ms
k, n, p = 10 1000000 4
corr 39.290270 ms
sequ 8.182161 ms
direct 34.399724 ms
k, n, p = 1000 10000 3
corr 6.319400 ms
sequ 2.225807 ms
direct 9.001689 ms
答案 3 :(得分:0)
试试这段代码!
它会输出 NumPy 数组中每一段连续零的起止位置(闭区间),因此你可以用任意整数值替换这些范围内的零。
代码:
import numpy as np
from itertools import groupby
a = np.array([1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1.])
b = range(len(a))
# Group consecutive indices by the value stored there; every group whose key
# is 0 is one maximal run of zeros.
for value, run in groupby(b, lambda x: a[x]):
    if value != 0:
        continue
    lis = list(run)
    # Print the first and last index of the run (inclusive).
    print([min(lis), max(lis)])
输出
[1, 3]
[6, 9]
答案 4 :(得分:0)
这不是最快的方法,但您可以借助 scipy
为 n 维数组和 n 维模式实现一个稳健的解决方案。
import scipy
from scipy.ndimage import label
#=================
# Helper functions
#=================
# Nested list to nested tuple helper function
# from https://stackoverflow.com/questions/27049998/convert-a-mixed-nested-list-to-a-nested-tuple
def to_tuple(L):
return tuple(to_tuple(i) if isinstance(i, list) else i for i in L)
# Helper: convert an array into a set of (nested) tuples so that its entries
# become hashable.
def arr2set(arr):
    """Return the elements of ``arr.tolist()`` as a set, with nested lists turned into tuples."""
    nested = arr.tolist()
    return set(to_tuple(nested))
#===============
# Main algorithm
#===============
# Find every placement of ``pattern`` inside ``a`` (same number of dimensions)
# where each non-NaN cell of the pattern equals the corresponding cell of
# ``a``; NaN cells in the pattern are "don't care". Every match is then
# overwritten with ``replace_value``.
#
# Each candidate placement is checked cell by cell. (A region-labelling
# shortcut would only test that cells hold *some* pattern value, so a=[1, 1]
# would "match" pattern=[1, 2], and it misses patterns whose non-NaN cells are
# not connected.) NumPy is used directly because the ``scipy.*`` NumPy aliases
# were deprecated and removed from SciPy.
import numpy as np

if np.ndim(a) != np.ndim(pattern):
    raise ValueError('pattern must have the same number of dimensions as a')
notnan = ~np.isnan(pattern)
if not notnan.any():
    raise ValueError('pattern must contain at least one non-NaN value')
# Coordinates and values of the cells that must match, both in row-major order.
pattern_inds = np.argwhere(notnan)
pattern_vals = pattern[notnan]
# Anchor on the least frequent value to minimise the number of candidates
# (ties go to the smallest value because np.unique returns sorted values).
uniq_vals, counts = np.unique(pattern_vals, return_counts=True)
check_val = uniq_vals[np.argmin(counts)]
check_ind = np.argwhere(pattern == check_val)[0]
# Offsets of every required cell relative to the anchor cell.
rel_inds = pattern_inds - check_ind
a_shape = np.array(a.shape)
found_inds_list = []
for ind in np.argwhere(a == check_val):
    cand = rel_inds + ind
    # Skip placements that would stick out of ``a``.
    if (cand < 0).any() or (cand >= a_shape).any():
        continue
    inds = tuple(cand.T)
    if np.array_equal(a[inds], pattern_vals):
        found_inds_list.append(inds)
# Replace only after the search, so overlapping matches are all detected
# against the original data.
for inds in found_inds_list:
    a[inds] = replace_value
为 4 维情形生成随机测试数组、模式以及最终的替换值:
# Random 4-D test data for the pattern search. Uses NumPy directly: the
# ``scipy.random`` / ``scipy.nan`` / ``scipy.copy`` NumPy aliases were
# deprecated and removed from SciPy, and ``xrange`` does not exist in Python 3.
import numpy as np

replace_value = np.random.rand()  # Final value that you want to replace everything with
nan = np.nan  # Use this for places in the rectangular pattern array that you don't care about checking
# Generate random data
a = np.random.random([12, 12, 12, 12]) * 12
pattern = np.random.random([3, 3, 3, 3]) * 12
# Put the pattern in random places; start indices 0..9 keep each 3-wide block
# inside the 12-wide axes.
for i in range(4):
    j1, j2, j3, j4 = np.random.choice(10, 4, replace=True)
    a[j1:j1 + 3, j2:j2 + 3, j3:j3 + 3, j4:j4 + 3] = pattern
a_org = np.copy(a)
# Randomly insert NaNs ("don't care" cells) into the pattern after placing it
for i in range(20):
    j1, j2, j3, j4 = np.random.choice(3, 4, replace=True)
    pattern[j1, j2, j3, j4] = nan
运行主算法后......
>>> print found_inds_list[-1]
(array([ 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11,
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11], dtype=int64), array([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1,
1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1,
1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3],
dtype=int64), array([5, 5, 5, 6, 6, 6, 7, 7, 5, 5, 6, 6, 7, 7, 7, 5, 5, 6, 6, 7, 5, 5,
5, 6, 6, 7, 7, 5, 5, 6, 7, 7, 7, 5, 5, 5, 6, 6, 6, 7, 7, 7, 5, 5,
5, 6, 6, 7, 5, 5, 5, 6, 6, 6, 7, 5, 5, 6, 6, 6, 7, 7, 7],
dtype=int64), array([1, 2, 3, 1, 2, 3, 1, 3, 1, 3, 1, 3, 1, 2, 3, 2, 3, 1, 2, 2, 1, 2,
3, 1, 2, 1, 3, 1, 2, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2,
3, 1, 3, 1, 1, 2, 3, 1, 2, 3, 2, 1, 3, 1, 2, 3, 1, 2, 3],
dtype=int64))
>>>
>>> replace_value # Display value that's going to be replaced
0.9263912485289564
>>>
>>> print a_org[9:12, 1:4, 5:8, 1:4] # Display original rectangular window of replacement
[[[[ 9.68507479 1.77585089 5.06069382]
[10.63768984 11.41148096 1.13120712]
[ 6.83684611 2.46838238 11.40490158]]
[[ 9.17344668 11.21669704 7.60737639]
[ 3.14870787 6.22857282 5.61295454]
[ 4.32709261 8.00493326 9.96124294]]
[[ 4.16785078 10.66054365 2.95677408]
[11.53789218 2.70725911 11.98647139]
[ 5.00346525 4.75230895 4.05213149]]]
[[[11.23856096 8.45979355 7.53268864]
[ 6.14703327 11.90052117 5.48127994]
[ 2.16777734 10.27373562 7.75420214]]
[[10.04726853 11.44895046 7.78071007]
[ 0.79030038 3.69735083 1.51921116]
[11.29782542 2.58494314 9.8714708 ]]
[[ 7.9356587 1.48053473 9.71362122]
[ 5.11866341 3.43895455 6.86491947]
[ 8.33774813 5.66923131 2.27884056]]]
[[[ 0.75091443 2.02917445 5.68207987]
[ 4.58299978 7.14960394 9.13853129]
[10.60912932 4.52190424 0.6557605 ]]
[[ 0.54393627 8.02341744 11.69489975]
[ 9.09878676 10.60836714 2.41188805]
[ 9.13098333 6.12284334 8.9349382 ]]
[[ 5.84489355 10.19848245 1.65080169]
[ 2.75161562 1.05154552 0.17804374]
[ 3.3166642 10.74081484 5.13601563]]]]
>>>
>>> print a[9:12, 1:4, 5:8, 1:4] # Same window in the replaced array
[[[[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 2.46838238 0.92639125]]
[[ 0.92639125 11.21669704 0.92639125]
[ 0.92639125 6.22857282 0.92639125]
[ 0.92639125 0.92639125 0.92639125]]
[[ 4.16785078 0.92639125 0.92639125]
[ 0.92639125 0.92639125 11.98647139]
[ 5.00346525 0.92639125 4.05213149]]]
[[[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 0.92639125 5.48127994]
[ 0.92639125 10.27373562 0.92639125]]
[[ 0.92639125 0.92639125 7.78071007]
[ 0.79030038 0.92639125 1.51921116]
[ 0.92639125 0.92639125 0.92639125]]
[[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 0.92639125 0.92639125]]]
[[[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 7.14960394 0.92639125]
[ 0.92639125 4.52190424 0.6557605 ]]
[[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 0.92639125 0.92639125]
[ 9.13098333 0.92639125 8.9349382 ]]
[[ 0.92639125 10.19848245 0.92639125]
[ 0.92639125 0.92639125 0.92639125]
[ 0.92639125 0.92639125 0.92639125]]]]
>>>
>>> print pattern # The pattern that was matched and replaced
[[[[ 9.68507479 1.77585089 5.06069382]
[10.63768984 11.41148096 1.13120712]
[ 6.83684611 nan 11.40490158]]
[[ 9.17344668 nan 7.60737639]
[ 3.14870787 nan 5.61295454]
[ 4.32709261 8.00493326 9.96124294]]
[[ nan 10.66054365 2.95677408]
[11.53789218 2.70725911 nan]
[ nan 4.75230895 nan]]]
[[[11.23856096 8.45979355 7.53268864]
[ 6.14703327 11.90052117 nan]
[ 2.16777734 nan 7.75420214]]
[[10.04726853 11.44895046 nan]
[ nan 3.69735083 nan]
[11.29782542 2.58494314 9.8714708 ]]
[[ 7.9356587 1.48053473 9.71362122]
[ 5.11866341 3.43895455 6.86491947]
[ 8.33774813 5.66923131 2.27884056]]]
[[[ 0.75091443 2.02917445 5.68207987]
[ 4.58299978 nan 9.13853129]
[10.60912932 nan nan]]
[[ 0.54393627 8.02341744 11.69489975]
[ 9.09878676 10.60836714 2.41188805]
[ nan 6.12284334 nan]]
[[ 5.84489355 nan 1.65080169]
[ 2.75161562 1.05154552 0.17804374]
[ 3.3166642 10.74081484 5.13601563]]]]