给定段和系数的快速三次样条输出计算?

时间:2017-01-16 23:21:13

标签: python numpy

我有一个具有6个段的三次样条激活函数。中断(长度= 7(6 + 1))和系数是已知的(形状=(6L,4L))。此三次样条线用于相对于断点的输入,而不是绝对值。这是我的输出计算方法,包含一些样本数据:

def CubSpline(cs,x):
   breaks=cs['breaks']
   coefs=cs['coefs']
   pieces=cs['pieces']
   if x <= breaks[0] :
      return coefs[0][3]
   elif x >= breaks[pieces] :
      ind=pieces-1
      diff=breaks[ind+1]-breaks[ind]
   else :
      ind=0
      while x > breaks[ind+1] :
         ind += 1
      diff=x-breaks[ind]
   y=coefs[ind][3]+coefs[ind][2]*diff + coefs[ind][1]*diff*diff + coefs[ind][0]*diff*diff*diff
   return y

vcubspline=np.vectorize(CubSpline)

breaks=5*np.sort(np.random.randn(7))
coefs=np.random.randn(6,4)
pieces=6
cs=dict()
cs['pieces']=pieces
cs['breaks']=breaks
cs['coefs']=coefs
arr=np.random.randint(10,size=(500,500))

start=time.clock()
a=vcubspline2(cs,arr)
print a.shape
stop=time.clock()
print stop-start

我想知道这是否是计算输出的最快方法?如何改善这个?

1 个答案:

答案 0 :(得分:0)

如评论中所建议的那样,使用numpy.piecewise会导致代码更高效,更简洁。该函数直接与数组x一起工作,创建条件数组(涉及x的不等式)和相应的函数数组,然后将所有数据传递给piecewise

def cubicSpline(cs, x):
    breaks = cs['breaks']
    coefs = cs['coefs']
    x = np.clip(x, breaks[0], breaks[-1])   # clip to the interval in which spline is defined
    conditions = [x <= b for b in breaks]
    functions = [coefs[0][3]] + [lambda x, c=c, b=b: c[3] + c[2]*(x-b) + c[1]*(x-b)**2 + c[0]*(x-b)**3 for c, b in zip(coefs, breaks)]
    y = np.piecewise(x, conditions, functions)
    return y

breaks = 5*np.sort(np.random.randn(7))
coefs = np.random.randn(6,4)
cs = {'breaks': breaks, 'coefs': coefs}
arr = np.random.randint(10, size=(500,500))
a = cubicSpline(cs, arr)

最后一行执行时间为53毫秒(从timeit开始),而原始版本则为805毫秒。

输入字典的pieces字段是多余的,因为给定的断点和系数已经具有该信息。