如何有效地传递函数?

时间:2018-01-14 11:49:33

标签: python algorithm python-2.7 numpy scipy

动机

看看下面的图片。

enter image description here

给出红色,蓝色和绿色曲线。我想在x轴上的每个点找到主导曲线。这显示为图中的黑色图形。从红色,绿色和蓝色曲线的属性(在一段时间后增加和恒定),这可以归结为在右手边找到主导曲线,然后向左侧移动,找到所有交叉点并更新主导曲线。

这个概述的问题应该T次解决。这个问题有一个最后的转折点。下一次迭代的蓝色,绿色和红色曲线是通过前一次迭代的主导解决方案加上一些变化的参数构建的。作为上图中的示例:解决方案是黑色功能。此功能用于生成新的蓝色,绿色和红色曲线。然后问题又开始找到这些新曲线的主导者等。

简而言之问题
在每次迭代中,我从固定的右手侧开始,并评估所有三个函数,以查看哪个是主导函数。这种评估在迭代中花费的时间越来越长。 我的感觉是,我没有最佳地通过旧的支配功能来构建新的蓝色,绿色和红色曲线。原因:我在早期版本中遇到了最大递归深度错误。代码的其他部分也需要当前支配函数的值(绿色,红色或蓝色曲线必不可少)迭代越来越长。

对于5次迭代,只评估右侧一点上的函数增长:

结果是通过

产生的
test = A(5, 120000, 100000) 

然后运行

test.find_all_intersections()

>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141

我想知道为什么会这样,如果可以更有效地编程它。

详细代码说明

我很快总结了最重要的功能。完整的代码可以在下面找到。如果对代码有任何其他问题,我非常乐意详细说明/澄清。

  1. 方法u:用于生成新批次的重复任务 上面的绿色,红色和蓝色曲线我们需要旧的主导曲线。 u是在第一次迭代中使用的初始化。

  2. 方法_function_template:该函数生成的版本 通过使用不同的参数,绿色,蓝色和红色曲线。它回来了 单个输入的功能。

  3. 方法eval:这是每次生成蓝色,绿色和红色版本的核心功能。每次迭代需要三个不同的参数:vfunction,这是前一步骤的主导函数,ms,它们是影响结果曲线形状的两个参数(flaots)。其他参数在每次迭代中都是相同的。在代码中,每次迭代都有ms的样本值。对于更令人讨厌的人:它是近似一个积分,其中ms是潜在正态分布的预期平均值和标准差。近似是通过Gauss-Hermite节点/权重完成的。

  4. 方法find_all_intersections:这是查找的核心方法 每次迭代都是主导者。它构成了一个主导 通过蓝色,绿色和红色的片段连接起作用 曲线。这是通过函数piecewise

  5. 实现的

    这是完整的代码

    import numpy as np
    import pandas as pd
    from scipy.optimize import brentq
    import multiprocessing as mp
    import pathos as pt
    import timeit
    import math
    class A(object):
        def u(self, w):
            _w = np.asarray(w).copy()
            _w[_w >= 120000] = 120000
            _p = np.maximum(0, 100000 - _w)
            return _w - 1000*_p**2
    
        def __init__(self, T, upper_bound, lower_bound):
            self.T = T
            self.upper_bound = upper_bound
            self.lower_bound = lower_bound
    
        def _function_template(self, *args):
            def _f(x):
                return self.evalv(x, *args)
            return _f
    
        def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
            _A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
            _W = (_A.T * w).T
            _W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
                                                                 len(gauss_nodes))
            evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
            return c + evalue
    
        def find_all_intersections(self):
    
            # the hermite gauss weights and nodes for integration
            # and additional paramters used for function generation
    
            gauss = np.polynomial.hermite.hermgauss(10)
            gauss_nodes = gauss[0]
            gauss_weights = gauss[1]
            r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                            1., 1., 1., 1., 1., 1., 1., 1., 1.])
            m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
                 [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
                 [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
                 [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
                 [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
                 [0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]
    
            s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
                 [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
                 [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
                 [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
                 [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
                 [0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]
    
            self.solution = []
    
            n_cpu = mp.cpu_count()
            pool = pt.multiprocessing.ProcessPool(n_cpu)
    
            # this function is used for multiprocessing
            def call_f(f, x):
                return f(x)
    
            # this function takes differences for getting cross points
            def _diff(f_dom, f_other):
                def h(x):
                    return f_dom(x) - f_other(x)
                return h
    
            # finds the root of two function
            def find_roots(F, u_bound, l_bound):
                    try:
                        sol = brentq(F, a=l_bound,
                                     b=u_bound)
                        if np.absolute(sol - u_bound) > 1:
                            return sol
                        else:
                            return l_bound
                    except ValueError:
                        return l_bound
    
            # piecewise function
            def piecewise(l_comp, l_f):
                def f(x):
                    _ind_f = np.digitize(x, l_comp) - 1
                    if np.isscalar(x):
                        return l_f[_ind_f](x)
                    else:
                        return np.asarray([l_f[_ind_f[i]](x[i])
                                           for i in range(0, len(x))]).ravel()
                return f
    
            _u = self.u
    
            for t in range(self.T-1, -1, -1):
                print('iteration' + ' ' + str(t))
    
                l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
                l_ordered_functions = []
                l_roots = []
                l_solution = []
    
                # build all function variations
    
                l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
                                                       gauss_weights, gauss_nodes)
                               for i in range(0, len(m[t]))]
    
                # get the best solution for the upper bound on the very
                # right hand side of wealth interval
    
                array_functions = np.asarray(l_functions)
                start_time = timeit.default_timer()
                functions_values = pool.map(call_f, array_functions.tolist(),
                                            len(m[t]) * [u_bound])
                elapsed = timeit.default_timer() - start_time
                print('to compute function values it took')
                print(elapsed)
    
                ind = np.argmax(functions_values)
                cross_points = len(m[t]) * [u_bound]
                l_roots.insert(0, u_bound)
                max_m = m[t][ind]
                l_solution.insert(0, max_m)
    
                # move from the upper bound twoards the lower bound
                # and find the dominating solution by exploring all cross
                # points.
    
                test = True
    
                while test:
                    l_ordered_functions.insert(0, array_functions[ind])
                    current_max = l_ordered_functions[0]
    
                    l_c_max = len(m[t]) * [current_max]
                    l_u_cross = len(m[t]) * [cross_points[ind]]
    
                    # Find new cross points on the smaller interval
    
                    diff = pool.map(_diff, l_c_max, array_functions.tolist())
                    cross_points = pool.map(find_roots, diff,
                                            l_u_cross, len(m[t]) * [l_bound])
    
                    # update the solution, cross points and current
                    # dominating function.
    
                    ind = np.argmax(cross_points)
                    l_roots.insert(0, cross_points[ind])
                    max_m = m[t][ind]
                    l_solution.insert(0, max_m)
    
                    if cross_points[ind] <= l_bound:
                        test = False
    
                l_ordered_functions.insert(0, l_functions[0])
                l_roots.insert(0, 0)
                l_roots[-1] = np.inf
    
                l_comp = l_roots[:]
                l_f = l_ordered_functions[:]
    
                # build piecewise function which is used for next
                # iteration.
    
                _u = piecewise(l_comp, l_f)
                _sol = pd.DataFrame(data=l_solution,
                                    index=np.asarray(l_roots)[0:-1])
                self.solution.insert(0, _sol)
            return self.solution
    

2 个答案:

答案 0 :(得分:4)

让我们从更改代码开始输出当前迭代:

_u = self.u
for t in range(0, self.T):
    print(t)
    lparams = np.random.randint(self.a, self.b, 6).reshape(3, 2).tolist()
    functions = [self._function_template(_u, *lparams[i])
                 for i in range(0, 3)]
    # evaluate functions
    pairs = list(itertools.combinations(functions, 2))
    fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
    ind = np.sort(np.unique(np.random.randint(self.a, self.b, 10)))
    _u = _temp(ind, np.asarray(functions)[ind % 3])

查看导致行为的行,

fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]

感兴趣的功能是Fdiff。后者很简单,前者:

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError:
        pass

嗯,吞咽异常,让我们看看如果我们发生了什么:

def F(f, a, b):
    brentq(f, a=a, b=b)

对于第一个函数和第一个迭代,立即抛出错误:

  

ValueError:f(a)和f(b)必须有不同的符号

查看docs这是根查找功能brentq的先决条件。让我们再次更改定义,以便在每次迭代时监控这个条件。

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError as e:
        print(e)

输出

i
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs

表示i范围从0到57.意思是,函数F第一次执行任何实际工作都是i=58。并且它一直为i的更高值而这样做。

结论:这些较高的值需要更长的时间,因为:

  1. 根本不会计算较低的值
  2. i>58
  3. 的计算次数呈线性增长

答案 1 :(得分:3)

您的代码实在太复杂,无法解释您的问题 - 争取更简单的事情。有时您必须编写代码才能证明问题。

我只是根据你的描述而不是你的代码进行了攻击(尽管我运行了代码并经过验证)。这是你的问题:

  

方法eval:这是生成蓝色,绿色和绿色的核心功能   每次都是红色版本。它每个需要三个不同的参数   迭代:vfunction是主导的函数   前一步,m和s是影响的两个参数(flaots)   得到的曲线的形状。

您的vfunction参数在每次迭代时都会更复杂。您正在传递在先前迭代中构建的嵌套函数,这会导致递归执行。每次迭代都会增加递归调用的深度。

你怎么能避免这种情况?没有简单或内置的方式。最简单的答案是 - 假设这些函数的输入是一致的 - 存储函数结果(即数字)而不是函数本身。只要您拥有有限数量的已知输入,就可以执行此操作。

如果基础功能的输入不一致,那么没有捷径。您需要反复评估这些基础功能。我看到你正在对底层函数进行一些分段拼接 - 你可以测试这样做的成本是否超过了简单地获取每个底层函数的max的成本。

我跑的测试(10次迭代)花了几秒钟。我不认为这是一个问题。