A fast way to approximate np.random.dirichlet with large dimension

Time: 2018-02-24 05:35:07

Tags: performance numpy optimization function-approximation

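I want to evaluate np.random.dirichlet with a large dimension as fast as possible. More precisely, I want a function that approximates the one below and is at least 10 times faster. Empirically, I have observed that small-dimension versions of this function output one or two entries on the order of 0.1, while every other entry is so small as to be irrelevant. This observation is not based on any rigorous evaluation, though. The approximation does not need to be very accurate, but I want something that is not too crude, since I am using this noise for MCTS.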

import timeit
import numpy as np

def g():
    np.random.dirichlet([0.03] * 4840)

>>> timeit.timeit(g, number=1000)
0.35117408499991143
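The informal observation about a few dominant entries can be quantified with a short script. The following is only a sketch: the threshold of 0.01 and the sample count are arbitrary choices made for illustration, not values from the question.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only so that reruns are repeatable
alpha, n_comp, n_sampl = 0.03, 4840, 100
x = rng.dirichlet([alpha] * n_comp, size=n_sampl)

# Hypothetical relevance threshold; adjust to whatever counts as "large".
threshold = 0.01

# Per sample: how many entries exceed the threshold,
# and how much of the total mass those entries carry.
n_large = (x > threshold).sum(axis=1)
mass_large = np.where(x > threshold, x, 0.0).sum(axis=1)

print("mean number of entries above threshold:", n_large.mean())
print("mean mass carried by those entries:   ", mass_large.mean())
```

Changing n_comp shows how the picture differs between small and large dimensions.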

1 answer:

Answer 0 (score: 1)

Assuming your alpha is the same for every component and is reused over many iterations, you could tabulate the ppf (inverse CDF) of the corresponding gamma distribution. This is probably available as scipy.stats.gamma.ppf, but we can also use scipy.special.gammaincinv. That function seems rather slow, so building the table is a significant upfront investment.
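The idea rests on two facts: a Dirichlet sample can be obtained by normalizing independent gamma draws, and a gamma draw can be obtained by pushing a uniform number through the gamma ppf. A minimal sketch of both, assuming unit scale and a seeded generator chosen only for illustration:

```python
import numpy as np
from scipy import special, stats

alpha = 0.03

# For unit scale, the gamma ppf is the inverse regularized incomplete gamma function.
p = np.array([0.9, 0.95, 0.99])
ppf_stats = stats.gamma.ppf(p, alpha)
ppf_special = special.gammaincinv(alpha, p)
print(np.allclose(ppf_stats, ppf_special))

# Inverse-transform sampling without a table: uniform -> gamma -> normalize.
rng = np.random.default_rng(0)
u = rng.random((3, 4840))
g = special.gammaincinv(alpha, u)
x = g / g.sum(axis=1, keepdims=True)  # each row is one symmetric Dirichlet draw
print(x.shape, x.sum(axis=1))
```

Tabulating replaces the per-draw call to gammaincinv with an array lookup, which is where the speed-up is expected to come from.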

Here is a rough implementation of the general idea:

import numpy as np
from scipy import special

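class symm_dirichlet:
    """Approximate symmetric Dirichlet sampler based on a tabulated gamma ppf."""
    def __init__(self, alpha, resolution=2**16):
        self.alpha = alpha
        self.resolution = resolution
        # Midpoints of `resolution` equal-width probability bins on (0, 1)
        self.range, delta = np.linspace(0, 1, resolution,
                                        endpoint=False, retstep=True)
        self.range += delta / 2
        # Gamma(alpha) ppf at those probabilities (the slow, one-off step)
        self.table = special.gammaincinv(self.alpha, self.range)

    def draw(self, n_sampl, n_comp, interp='nearest'):
        if interp != 'nearest':
            raise NotImplementedError
        # Look up approximate gamma variates at uniformly random table indices
        gamma = self.table[np.random.randint(0, self.resolution,
                                             (n_sampl, n_comp))]
        # Normalizing independent gamma draws yields a Dirichlet sample per row
        return gamma / gamma.sum(axis=1, keepdims=True)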

import time, timeit

t0 = time.perf_counter()
X = symm_dirichlet(0.03)
t1 = time.perf_counter()
print(f'Upfront cost {t1-t0:.3f} sec')
print('Running cost per 1000 samples of width 4840')
print('tabulated           {:.3f} sec'.format(timeit.timeit(
    'X.draw(1, 4840)', number=1000, globals=globals())))
print('np.random.dirichlet {:.3f} sec'.format(timeit.timeit(
    'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))

Better check that it is approximately correct:

[Image from the original answer not available]
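One possible way to run such a check numerically is sketched below. It is not the original answer's code: the helper tabulated_dirichlet is a hypothetical, self-contained re-statement of the table lookup, and comparing the mean of the k largest entries per sample is just one statistic picked for illustration.

```python
import numpy as np
from scipy import special

def tabulated_dirichlet(rng, alpha, n_sampl, n_comp, resolution=2**16):
    """Hypothetical helper: symmetric Dirichlet draws via a tabulated gamma ppf."""
    # Midpoints of `resolution` equal-width probability bins on (0, 1)
    p = (np.arange(resolution) + 0.5) / resolution
    table = special.gammaincinv(alpha, p)
    g = table[rng.integers(0, resolution, (n_sampl, n_comp))]
    return g / g.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)  # seeded only for repeatability
alpha, n_comp, n_sampl = 0.03, 4840, 200

approx = tabulated_dirichlet(rng, alpha, n_sampl, n_comp)
exact = rng.dirichlet([alpha] * n_comp, size=n_sampl)

# Average size of the k largest entries per sample (ascending order).
k = 10
top_approx = np.sort(approx, axis=1)[:, -k:].mean(axis=0)
top_exact = np.sort(exact, axis=1)[:, -k:].mean(axis=0)

print("tabulated:", np.round(top_approx, 4))
print("exact:    ", np.round(top_exact, 4))
```

If the two printed profiles differ noticeably, a larger resolution is the first thing to try, since the table cuts off the extreme upper tail of the gamma distribution.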