我正在尝试使用花哨的索引而不是循环来加速Numpy中的函数。据我所知,我已经正确地实现了花哨的索引版本。问题是两个函数(循环和花式索引)不会返回相同的结果。我不知道为什么。值得指出的是,如果使用较小的数组(例如,20 x 20 x 20),函数会返回相同的结果。
下面我已经包含了重现错误所需的一切。如果函数确实返回相同的结果,那么行find_maxdiff(data) - find_maxdiff_fancy(data)
应返回一个满零的数组。
from numpy import *
def rms(data, axis=0):
return sqrt(mean(data ** 2, axis))
def find_maxdiff(data):
samples, channels, epochs = shape(data)
window_size = 50
maxdiff = zeros(epochs)
for epoch in xrange(epochs):
signal = rms(data[:, :, epoch], axis=1)
for t in xrange(window_size, alen(signal) - window_size):
amp_a = mean(signal[t-window_size:t], axis=0)
amp_b = mean(signal[t:t+window_size], axis=0)
the_diff = abs(amp_b - amp_a)
if the_diff > maxdiff[epoch]:
maxdiff[epoch] = the_diff
return maxdiff
def find_maxdiff_fancy(data):
samples, channels, epochs = shape(data)
window_size = 50
maxdiff = zeros(epochs)
signal = rms(data, axis=1)
for t in xrange(window_size, alen(signal) - window_size):
amp_a = mean(signal[t-window_size:t], axis=0)
amp_b = mean(signal[t:t+window_size], axis=0)
the_diff = abs(amp_b - amp_a)
maxdiff[the_diff > maxdiff] = the_diff
return maxdiff
data = random.random((600, 20, 100))
find_maxdiff(data) - find_maxdiff_fancy(data)
data = random.random((20, 20, 20))
find_maxdiff(data) - find_maxdiff_fancy(data)
答案 0 :(得分:3)
问题在于这一行:
maxdiff[the_diff > maxdiff] = the_diff
左侧仅选择maxdiff的某些元素,但右侧包含the_diff的所有元素。这应该工作:
replaceElements = the_diff > maxdiff
maxdiff[replaceElements] = the_diff[replaceElements]
或简单地说:
maxdiff = maximum(maxdiff, the_diff)
至于为什么20x20x20尺寸似乎有用:这是因为你的窗口尺寸太大,所以没有任何东西被执行。
答案 1 :(得分:0)
首先,如果我理解正确,你的信号现在是2D - 所以我认为明确地对它进行索引会更清楚(例如,amp_a = mean(signal [t-window_size:t,:],axis = 0)。与alen(信号)类似 - 在两种情况下都应该是样本,所以我认为使用它会更清楚。
每当你在t
循环中实际执行某些操作时都是错误的 - 当在{20} 20x20示例中samples < window_lenght
时,该循环永远不会被执行。只要该循环执行多次(即samples > 2 *window_length+1
),就会出现错误。不知道为什么 - 他们确实看起来与我相同。