具有取决于先前值的运行值的矢量化循环(+ if语句)

时间:2018-11-27 03:13:39

标签: python numba

我习惯于用Python编写矢量化语句和列表解析,但是在依赖循环中先前值的“运行”计算以及if语句中都出现了一个问题。示意图如下:

def my_loop(x, a=0.5, b=0.9):
  out = np.copy(x)
  prev_val = 0 
  for i in np.arange(x.shape[0]):

      if x[i] < prev_val:
          new_val = (1-a)*x[i] + a*prev_val
      else:
          new_val = (1-b)*x[i] + b*prev_val

      out[i] = new_val

      prev_val = new_val

  return out

我还无法弄清楚如何将其向量化(例如,通过使用某种累加器),所以我会问:有没有办法使它更加Pythonic /更快?

我以前见过有关在有一个if语句时进行向量化的文章-通常通过np.where()解决-但是没有一个地方有一个取决于其先前状态的“运行中”值...所以我没有尚未发现任何重复的问题(this one与通常意义上的矢量化无关,this one与“先前值”有关,但涉及列表索引)。

到目前为止,我已经尝试使用np.vectorize和numba的@jit,它们的运行速度确实有所提高,但是都无法满足我的期望。有什么我想念的吗? (也许带有map()的东西?)谢谢。

(是的,在a = b的情况下这很容易!)

2 个答案:

答案 0 :(得分:1)

在nopython模式下进行JIT的速度更快。引用numba文档:

  

Numba有两种编译模式:nopython模式和对象模式。的   前者会产生更快的代码,但有一些局限性   Numba退回后者。为了防止Numba退回,   并引发错误,请传递nopython = True。

@nb.njit(cache=True)
def my_loop5(x, a=0.5, b=0.9):
  out = np.zeros(x.shape[0],dtype=x.dtype)

  for i in range(x.shape[0]):
      if x[i] < out[i-1]:
          out[i] = (1-a) * x[i] + a * out[i-1]
      else:
          out[i] = (1-b) * x[i] + b * out[i-1]
  return out

因此打开:

x = np.random.uniform(low=-5.0, high=5.0, size=(1000000,))

时间是:

  

my_loop4:0.235秒

     

my_loop5:0.193秒

HTH。

答案 1 :(得分:0)

我意识到,通过删除虚拟变量,可以将此代码放入numba和from numba import jit, autojit @autojit def my_loop4(x, a=0.5, b=0.9): out = np.zeros(x.shape[0],dtype=x.dtype) for i in np.arange(x.shape[0]): if x[i] < out[i-1]: out[i] = (1-a)*x[i] + a*out[i-1] else: out[i] = (1-b)*x[i] + b*out[i-1] return out 可以发挥其魔力并使它“快速”的形式:

    int PERIODIC_SYNC_JOB_ID = 0;
    long interval  = 1000 * 60 * 20;
    JobInfo.Builder builder = new JobInfo.Builder(PERIODIC_SYNC_JOB_ID,
            new ComponentName(getApplicationContext(), SampleJobService.class));
    JobInfo jobInfo = builder.setPeriodic(interval).build();

    JobScheduler jobScheduler = (JobScheduler) getApplicationContext().getSystemService(Context.JOB_SCHEDULER_SERVICE);
    jobScheduler.schedule(jobInfo);

    Log.d("JobScheduler", "Sample job is scheduled every " + interval + " ms");

没有@autojit,这仍然很慢。但是,随着它,...问题解决了。因此,删除不必要的变量 并添加@autojit就是解决问题的办法。