针对给定条件优化检查numpy数组的每个元素

时间:2019-07-19 15:16:12

标签: python numpy optimization simulation

我有一个仿真,该仿真在给定的循环中任意多次查看numpy数组,以检查是否有任何元素超过了某个阈值。如果某个元素超过了阈值,那么我需要跟踪它是哪个元素,因此我可以对该特定元素进行操作。我有一个函数可以做到这一点,但这是我的代码的主要瓶颈。运行仿真所花费的时间中,约有90%用于执行此功能。

这是我的功能:

    def scanLattice(s_array,t_array,L):
        failures = []
        for i in xrange(L):
            for j in xrange(L):
                if s_array[i,j] >= t_array[i,j]:
                    M = L*j + i
                    failures.append(M)
        return failures

s_array是将检查其值的数组;它的大小为[L,L]。 t_array也是[L,L],并保留用于检查s_array中的值的阈值。阈值不一致。 t_array的随机1%元素具有与数组其余部分不同的阈值,这些阈值是一致的。我跟踪这些不一致的站点。因此,我遍历s_array的行和列,并检查t_array的相应元素,如果满足阈值条件,则将该站点的索引添加到列表中。

对于如何以更有效的方式重写此功能的任何建议,将不胜感激。

1 个答案:

答案 0 :(得分:0)

您在函数中输入的实际上是一个展平数组中的索引列表,该数组作为多维数组的一维视图。只需计算s_array大于或等于t_array的值的掩码即可。接下来,使用np.flatnonzero()对非假索引隐式展平掩码。还有一个问题。值L * j + i是转置数组中的扁平索引,因此应用ndarray.T

np.flatnonzero((s_array>=t_array).T)