矢量化以下python代码?

时间:2018-01-17 04:16:32

标签: python arrays loops numpy

我正在尝试使用python中的两个矩阵来矢量化以下操作。

f= matrix([[ 96],
    [192],
    [288],
    [384]], dtype=int32)

g = matrix([[   0.],
    [  70.],
    [ 200.],
    [  60.]])

需要创建z而不创建循环,使得z是第一列的累积和与z的最后一个值和另一个矩阵g之和的最大值。这个循环被称为数千次,因此减慢了运行时间。

for i in range(4):
if i != 0:
    z[i] = max(f[i], z[i-1] + g[i])
else:
    z[0] = f[i]

有关如何对此代码进行矢量化的任何指导都非常有用。

提前致谢。

1 个答案:

答案 0 :(得分:0)

这是一个矢量化版本。它使用maximumf之间差异的累积cumsum(g)来预测f[i]大于z[i]的点:

时序:

N = 10
loopy                 0.00594156 ms
vect                  0.03193051 ms
N = 100
loopy                 0.05560229 ms
vect                  0.03186400 ms
N = 1000
loopy                 0.57484017 ms
vect                  0.04492043 ms
N = 10000
loopy                 5.75115310 ms
vect                  0.15519847 ms
N = 100000
loopy                57.30253551 ms
vect                  1.69428380 ms

代码:

import numpy as np

import types
from timeit import timeit

def setup_data(N):
    g = np.random.random((N,))
    f = 2 + np.cumsum(np.random.random(N,))
    return f, g

def f_loopy(f, g):
    N, = f.shape
    z = np.empty_like(f)
    for i in range(N):
        if i != 0:
            z[i] = max(f[i], z[i-1] + g[i])
        else:
            z[0] = f[i]
    return z

def f_vect(f, g):
    N, = f.shape
    gg = np.cumsum(g)
    rmx = np.maximum.accumulate(f - gg)
    sw = np.r_[0, 1 + np.flatnonzero(rmx[:-1] != rmx[1:]), N]
    return gg + np.repeat(f[sw[:-1]]-gg[sw[:-1]], np.diff(sw))

for N in [10, 100, 1000, 10000, 100000]:
    data = setup_data(N)
    ref = f_loopy(*data)
    print(f'N = {N}')
    for name, func in list(globals().items()):
        if not name.startswith('f_') or not isinstance(func, types.FunctionType):
            continue
        try:
            assert np.allclose(ref, func(*data))
            print("{:16s}{:16.8f} ms".format(name[2:], timeit(
                'f(*data)', globals={'f':func, 'data':data}, number=100)*10))
        except:
            print("{:16s} apparently failed".format(name[2:]))