我有一个仿真,该仿真在给定的循环中任意多次查看numpy数组,以检查是否有任何元素超过了某个阈值。如果某个元素超过了阈值,那么我需要跟踪它是哪个元素,因此我可以对该特定元素进行操作。我有一个函数可以做到这一点,但这是我的代码的主要瓶颈。运行仿真所花费的时间中,约有90%用于执行此功能。
这是我的功能:
def scanLattice(s_array,t_array,L):
failures = []
for i in xrange(L):
for j in xrange(L):
if s_array[i,j] >= t_array[i,j]:
M = L*j + i
failures.append(M)
return failures
s_array
是将检查其值的数组;它的大小为[L,L]。 t_array
也是[L,L],并保留用于检查s_array
中的值的阈值。阈值不一致。 t_array的随机1%元素具有与数组其余部分不同的阈值,这些阈值是一致的。我跟踪这些不一致的站点。因此,我遍历s_array
的行和列,并检查t_array
的相应元素,如果满足阈值条件,则将该站点的索引添加到列表中。
对于如何以更有效的方式重写此功能的任何建议,将不胜感激。
答案 0 :(得分:0)
您在函数中输入的实际上是一个展平数组中的索引列表,该数组作为多维数组的一维视图。只需计算s_array大于或等于t_array的值的掩码即可。接下来,使用np.flatnonzero()对非假索引隐式展平掩码。还有一个问题。值L * j + i是转置数组中的扁平索引,因此应用ndarray.T
np.flatnonzero((s_array>=t_array).T)