如何从numpy数组的时间点创建一个掩码?

时间:2012-08-03 19:23:14

标签: python arrays numpy mask

数据是一个包含2500个测量时间序列的矩阵。我需要平均每个时间序列的时间,丢弃在尖峰周围记录的数据点(在区间tspike-dt * 10 ... tspike + 10 * dt)。尖峰时间的数量对于每个神经元是可变的并且存储在具有2500个条目的字典中。我当前的代码迭代神经元和尖峰时间,并将屏蔽值设置为NaN。然后调用bottleneck.nanmean()。但是这个代码在当前版本中会变慢,我想知道有更快的解决方案。谢谢!

import bottleneck
import numpy as np
from numpy.random import rand, randint

t = 1
dt = 1e-4
N = 2500
dtbin = 10*dt

data = np.float32(ones((N, t/dt)))
times = np.arange(0,t,dt)
spiketimes = dict.fromkeys(np.arange(N))
for key in spiketimes:
  spiketimes[key] = rand(randint(100))

means = np.empty(N)

for i in range(N):        
  spike_times = spiketimes[i]
  datarow = data[i]
  if len(spike_times) > 0:
    for spike_time in spike_times:                        
      start=max(spike_time-dtbin,0)
      end=min(spike_time+dtbin,t)
      idx = np.all([times>=start,times<=end],0)
      datarow[idx] = np.NaN
  means[i] = bottleneck.nanmean(datarow)

2 个答案:

答案 0 :(得分:0)

您可以直接索引所需的值并使用nanmean,而不是使用mean

means[i] = data[ (times<start) | (times>end) ].mean()

如果我误解了你确实需要你的索引,你可以尝试

means[i] = data[numpy.logical_not( np.all([times>=start,times<=end],0) )].mean()

同样在代码中你可能不想使用if len(spike_times) > 0(我假设你在每次迭代时删除了尖峰时间,否则该语句将永远为真,你将有一个无限循环),只使用{ {1}}。

答案 1 :(得分:0)

代码中的绝大部分处理时间来自这一行:

idx = np.all([times>=start,times<=end],0)

这是因为对于每个峰值,您将每个值与开始和结束进行比较。由于您在此示例中有统一的时间步长(我假设您的数据也是如此),因此简单地计算开始和结束索引要快得多:

# This replaces the last loop in your example:
for i in range(N):        
    spike_times = spiketimes[i]
    datarow = data[i]
    if len(spike_times) > 0:
        for spike_time in spike_times:
            start=max(spike_time-dtbin,0)
            end=min(spike_time+dtbin,t)
            #idx = np.all([times>=start,times<=end],0)
            #datarow[idx] = np.NaN
            datarow[int(start/dt):int(end/dt)] = np.NaN
    ## replaced this with equivalent for testing
    means[i] = datarow[~np.isnan(datarow)].mean()  

这将我的运行时间从约100秒减少到约1.5秒。 您还可以通过在spike_times上对循环进行矢量化来节省更多时间。这种影响取决于数据的特征(对于高峰值率应该是最有效的):

kernel = np.ones(20, dtype=bool)
for i in range(N):        
    spike_times = spiketimes[i]
    datarow = data[i]
    mask = np.zeros(len(datarow), dtype=bool)
    indexes = (spike_times / dt).astype(int)
    mask[indexes] = True  
    mask = np.convolve(mask, kernel)[10:-9]

    means[i] = datarow[~mask].mean()