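Efficiently extracting numpy subarrays from a mask

Time: 2017-04-13 07:17:08

Tags: python performance numpy

I am looking for a pythonic way to extract multiple subarrays from a given array using a mask, as shown in this example: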

import numpy as np

a = np.array([10, 5, 3, 2, 1])
m = np.array([True, True, False, True, True])

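The output should be a collection of arrays like the one below, where each contiguous "region" of True values in the mask m (True values adjacent to each other) gives the indices of one subarray.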

L[0] = np.array([10, 5])
L[1] = np.array([2, 1])

3 answers:

Answer 0 (score: 2)

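def splitByBool(a, m):
    # Cut wherever the mask changes value; the pieces then alternate True/False
    pieces = np.split(a, np.nonzero(np.diff(m))[0] + 1)
    # If the mask starts with True, the even-indexed pieces are the True runs
    return pieces[::2] if m[0] else pieces[1::2]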

This returns a list of arrays, one for each block where m is True.
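To see why the function has to look at m[0], here is a small sketch that prints the intermediate values for a mask that starts with False. The sample arrays `a2` and `m2` are made up for illustration, and the function is restated (same logic as above) only so the block runs on its own.

```python
import numpy as np

def splitByBool(a, m):
    # Same logic as the answer above, restated so this block is self-contained
    pieces = np.split(a, np.nonzero(np.diff(m))[0] + 1)
    return pieces[::2] if m[0] else pieces[1::2]

# Hypothetical sample data: the mask starts with False this time
a2 = np.array([7, 8, 9, 4, 6, 1])
m2 = np.array([False, True, True, False, False, True])

# np.diff on a boolean array marks every position where the value changes
changes = np.diff(m2)                 # [True, False, True, False, True]
cuts = np.nonzero(changes)[0] + 1     # [1, 3, 5]

# The pieces alternate False-run, True-run, ... because m2[0] is False
pieces = np.split(a2, cuts)           # [7], [8, 9], [4, 6], [1]
result = splitByBool(a2, m2)          # keeps the odd-indexed pieces
print(result)
```

Since m2[0] is False, the odd-indexed pieces [8, 9] and [1] are the True runs; with a mask starting with True the even-indexed ones would be kept instead.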

Answer 1 (score: 2)

Here's one approach -

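def separate_regions(a, m):
    # Pad the mask with False so runs touching either end get a start and a stop edge
    m0 = np.concatenate(( [False], m, [False] ))
    # Positions where the padded mask flips: alternating run starts and stops
    idx = np.flatnonzero(m0[1:] != m0[:-1])
    # Slice a between each (start, stop) pair
    return [a[idx[i]:idx[i+1]] for i in range(0,len(idx),2)]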

Sample run -

In [41]: a = np.array([10, 5, 3, 2, 1])
    ...: m = np.array([True, True, False, True, True])
    ...: 

In [42]: separate_regions(a, m)
Out[42]: [array([10,  5]), array([2, 1])]
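The padding step is what makes runs at either end of the mask work. A short sketch of the intermediate values for the sample input follows; the names `starts`, `stops` and `regions` are added here for illustration and are not part of the original function.

```python
import numpy as np

a = np.array([10, 5, 3, 2, 1])
m = np.array([True, True, False, True, True])

# Pad with False on both sides so every True run has a rising and a falling edge
m0 = np.concatenate(([False], m, [False]))   # F T T F T T F

# Indices where neighbouring values differ; they alternate start, stop, start, ...
idx = np.flatnonzero(m0[1:] != m0[:-1])      # [0, 2, 3, 5]

# Illustrative names: even entries are run starts, odd entries are run stops
starts, stops = idx[::2], idx[1::2]          # [0, 3] and [2, 5]
regions = [a[s:e] for s, e in zip(starts, stops)]
print(regions)
```

Because the edges are found by plain slicing and comparison, the only Python-level loop is the final one over the regions themselves.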

Runtime test

Other approach -

# @kazemakase's soln
def zip_split(a, m):
    d = np.diff(m)
    cuts = np.flatnonzero(d) + 1

    asplit = np.split(a, cuts)
    msplit = np.split(m, cuts)

    L = [aseg for aseg, mseg in zip(asplit, msplit) if np.all(mseg)]
    return L

Timings -

In [49]: a = np.random.randint(0,9,(100000))

In [50]: m = np.random.rand(100000)>0.2

# @kazemakase's solution
In [51]: %timeit zip_split(a,m)
10 loops, best of 3: 114 ms per loop

# @Daniel Forsman's solution
In [52]: %timeit splitByBool(a,m)
10 loops, best of 3: 25.1 ms per loop

# Proposed in this post
In [53]: %timeit separate_regions(a, m)
100 loops, best of 3: 5.01 ms per loop
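The absolute numbers above depend on the machine and the NumPy version. Below is a sketch for re-running the comparison outside IPython with the standard `timeit` module, after first checking that the three functions return the same regions. The helper `same_regions`, the fixed seed and the repeat count are choices made for this sketch, not part of the original post.

```python
import timeit
import numpy as np

def splitByBool(a, m):
    pieces = np.split(a, np.nonzero(np.diff(m))[0] + 1)
    return pieces[::2] if m[0] else pieces[1::2]

def separate_regions(a, m):
    m0 = np.concatenate(([False], m, [False]))
    idx = np.flatnonzero(m0[1:] != m0[:-1])
    return [a[idx[i]:idx[i + 1]] for i in range(0, len(idx), 2)]

def zip_split(a, m):
    cuts = np.flatnonzero(np.diff(m)) + 1
    asplit = np.split(a, cuts)
    msplit = np.split(m, cuts)
    return [aseg for aseg, mseg in zip(asplit, msplit) if np.all(mseg)]

def same_regions(x, y):
    # Hypothetical helper: results match if they hold equal arrays in the same order
    return len(x) == len(y) and all(np.array_equal(p, q) for p, q in zip(x, y))

rng = np.random.default_rng(0)        # fixed seed so runs are repeatable
a = rng.integers(0, 9, 100000)
m = rng.random(100000) > 0.2

reference = separate_regions(a, m)
agree = (same_regions(reference, splitByBool(a, m))
         and same_regions(reference, zip_split(a, m)))
print("all approaches agree:", agree)

# Best of three single runs per function, reported in milliseconds
for fn in (zip_split, splitByBool, separate_regions):
    t = min(timeit.repeat(lambda: fn(a, m), number=1, repeat=3))
    print(f"{fn.__name__}: {t * 1e3:.2f} ms")
```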

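Answer 2 (score: 1)

This sounds like a natural application for np.split.

First you have to figure out where to cut the array, which is wherever the mask changes between True and False. Then discard all the pieces where the mask is False.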

import numpy as np

a = np.array([10, 5, 3, 2, 1])
m = np.array([True, True, False, True, True])

d = np.diff(m)
cuts = np.flatnonzero(d) + 1

asplit = np.split(a, cuts)
msplit = np.split(m, cuts)

L = [aseg for aseg, mseg in zip(asplit, msplit) if np.all(mseg)]

print(L[0])  # [10  5]
print(L[1])  # [2 1]
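One thing worth checking with this approach is a mask that never changes value, since then there are no cut points at all. A brief sketch with made-up inputs, wrapping the steps above in a helper whose name `true_runs` is hypothetical:

```python
import numpy as np

def true_runs(a, m):
    # Same steps as above, wrapped in a function (the name is hypothetical)
    cuts = np.flatnonzero(np.diff(m)) + 1
    asplit = np.split(a, cuts)
    msplit = np.split(m, cuts)
    return [aseg for aseg, mseg in zip(asplit, msplit) if np.all(mseg)]

a2 = np.array([4, 8, 15])

# No transitions: np.split returns the whole array as a single piece
all_true = true_runs(a2, np.ones(3, dtype=bool))     # one region, the full array
all_false = true_runs(a2, np.zeros(3, dtype=bool))   # no regions

print(all_true)
print(all_false)
```

In both cases the single piece is kept or dropped by the np.all(mseg) filter, so no special handling of m[0] is needed.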