给定一个维度,n和布尔运算符,我想通过应用运算符来查看给定的具有适当维度的张量是否与给定维度的第(i + 1)到第(i + n)个切片相比较。
换句话说,在一个更具体的情况下,假设我有一个2维的数组,我想比较5个值。我需要创建一个布尔数组,如果接下来的5个值都高于第1个值,第1行将在第1列中为true。类似地,在列和下一行中,第1列将比较第3行到第8行的第1列中的值或第2行中的第1列。
E.g。
[[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 0],
[7, 1],
[8, 2]]
应该给:
[[True, False],
[True, False],
[True, False]]
当运算符为all >=
时(其中所有接下来的5个项应该大于或等于),并且要比较的项数为5,并且比较是逐行的(轴0)。
我想在Numpy或Pandas中这样做,但更喜欢Numpy。
答案 0 :(得分:1)
一个技巧是使用Scipy's 1D minimum filter
并将当前元素与当前元素之间的最小值进行比较,并且当前元素的长度为n
。检查该间隔中的最小值,我们基本上针对所有元素检查greater-than
。
因此,我们会有一个解决方案,就像这样 -
from scipy.ndimage.filters import minimum_filter1d as minf
def rolling_comparison(a, W):
HW = (W-1)//2 # Half window size for offsetting kernel in min filter
v = minf(a,W,origin=-HW)
return v[:,1:] > a[:,:-1]
以下是具有各种窗口大小的示例测试 -
In [245]: a
Out[245]:
array([[59, 86, 77, 31, 91, 88, 13, 77, 77, 39],
[12, 63, 98, 21, 69, 89, 93, 38, 52, 62],
[29, 58, 42, 74, 22, 27, 23, 40, 37, 11]])
In [246]: rolling_comparison(a, W=3)
Out[246]:
array([[False, False, False, False, False, False, True, False, False],
[ True, False, False, True, False, False, False, True, False],
[ True, False, False, False, True, False, False, False, False]])
In [247]: rolling_comparison(a, W=5)
Out[247]:
array([[False, False, False, False, False, False, True, False, False],
[ True, False, False, True, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False]])
In [248]: rolling_comparison(a, W=7)
Out[248]:
array([[False, False, False, False, False, False, False, False, False],
[ True, False, False, True, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False]])
解决案例
现在,列出的方法适用于2D数组的每一行。您似乎希望使其按列工作。此外,列出的方法reflects
边界的边界元素,而在您的情况下,您只对有效元素感兴趣。因此,为了适应您的情况,我们需要使用transpose
并剪切前半窗口大小。
因此,根据您的情况,我们会 -
In [82]: a
Out[82]:
array([[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6], # Made the second elem as 6 for variety
[7, 1],
[8, 2]])
In [83]: rolling_comparison(a.T, W=5).T[:3] # 3 is half window size for 5
Out[83]:
array([[ True, True],
[ True, False],
[ True, False]], dtype=bool)