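I want to evaluate np.random.dirichlet with a large dimension as quickly as possible. More precisely, I want a function that approximates the one below and is at least 10 times faster. Empirically, I have observed that small-dimension versions of this function output one or two entries of order 0.1, and every other entry is so small that it is negligible. However, this observation isn't based on any rigorous evaluation. The approximation doesn't need to be very accurate, but I'd like something not too crude, because I use this noise for Monte Carlo tree search (MCTS).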
```python
>>> import timeit
>>> import numpy as np
>>> def g():
...     np.random.dirichlet([0.03]*4840)
...
>>> timeit.timeit(g, number=1000)
0.35117408499991143
```
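To put a number on the observation that only a few entries matter, one could summarize the largest entries over many draws. This is a minimal sketch, assuming NumPy's `Generator` API (`np.random.default_rng`); the function name `summarize_dirichlet` and the `threshold` cutoff for "non-negligible" are hypothetical choices, not taken from the question.

```python
import numpy as np


def summarize_dirichlet(alpha=0.03, k=4840, n_samples=200, threshold=1e-3, seed=0):
    """Summarize how concentrated symmetric Dirichlet draws are.

    `threshold` is an arbitrary, hypothetical cutoff for a "non-negligible" entry.
    """
    rng = np.random.default_rng(seed)
    x = rng.dirichlet([alpha] * k, size=n_samples)  # shape (n_samples, k)
    top = np.sort(x, axis=1)[:, ::-1]                # each row sorted in descending order
    return {
        # Mean value of the single largest entry per draw.
        "largest_entry": top[:, 0].mean(),
        # Mean probability mass held by the 10 largest entries per draw.
        "top10_mass": top[:, :10].sum(axis=1).mean(),
        # Mean number of entries above the cutoff per draw.
        "n_above_threshold": (x > threshold).sum(axis=1).mean(),
    }


if __name__ == "__main__":
    print(summarize_dirichlet())
```

Varying `k` while keeping `alpha` fixed would show whether the "one or two entries of order 0.1" pattern from small dimensions carries over to width 4840.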
Answer (score: 1)
Assuming your alpha is fixed across components and used for many iterations, you could tabulate the ppf of the corresponding gamma distribution. This works because a symmetric Dirichlet sample is just a vector of i.i.d. Gamma(alpha, 1) draws divided by its sum, so fast gamma sampling is all that is needed. The ppf is probably available as `scipy.stats.gamma.ppf`, but we can also use `scipy.special.gammaincinv`. This function seems rather slow, so building the table is a significant upfront investment.

Here is a rough implementation of the general idea:
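A small sketch of why the table works: the ppf of Gamma(alpha, scale=1) is the inverse of the regularized lower incomplete gamma function, so `scipy.stats.gamma.ppf` and `scipy.special.gammaincinv` should agree, and feeding uniforms through it (inverse-transform sampling) and then normalizing yields a Dirichlet sample. The use of `np.random.default_rng` and the particular quantiles are illustrative assumptions.

```python
import numpy as np
from scipy import special, stats

alpha = 0.03
q = np.array([0.1, 0.5, 0.9])

# ppf of Gamma(alpha, scale=1) via scipy.stats and via scipy.special;
# both compute the inverse of the regularized lower incomplete gamma.
ppf_stats = stats.gamma.ppf(q, alpha)
ppf_special = special.gammaincinv(alpha, q)

# Inverse-transform sampling: uniforms pushed through the gamma ppf give
# i.i.d. Gamma(alpha) draws; dividing by their sum gives one
# Dirichlet(alpha, ..., alpha) sample of width 4840.
rng = np.random.default_rng(0)
g = special.gammaincinv(alpha, rng.random(4840))
sample = g / g.sum()
```

The answer's table simply precomputes `gammaincinv` on a fixed grid so that each draw becomes an array lookup instead of a fresh inversion.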
```python
import time
import timeit

import numpy as np
from scipy import special


class symm_dirichlet:
    """Approximate symmetric Dirichlet sampler using a tabulated gamma ppf."""

    def __init__(self, alpha, resolution=2**16):
        self.alpha = alpha
        self.resolution = resolution
        # Midpoints of `resolution` equal-width bins on (0, 1).
        self.range, delta = np.linspace(0, 1, resolution,
                                        endpoint=False, retstep=True)
        self.range += delta / 2
        # Gamma(alpha) ppf at each midpoint; this is the slow, upfront step.
        self.table = special.gammaincinv(self.alpha, self.range)

    def draw(self, n_sampl, n_comp, interp='nearest'):
        if interp != 'nearest':
            raise NotImplementedError
        # Pick a random table entry per component: nearest-neighbour
        # inverse-transform sampling of Gamma(alpha).
        gamma = self.table[np.random.randint(0, self.resolution,
                                             (n_sampl, n_comp))]
        # Normalizing i.i.d. gamma draws row-wise gives Dirichlet samples.
        return gamma / gamma.sum(axis=1, keepdims=True)


t0 = time.perf_counter()
X = symm_dirichlet(0.03)
t1 = time.perf_counter()
print(f'Upfront cost {t1-t0:.3f} sec')
print('Running cost per 1000 samples of width 4840')
print('tabulated {:3f} sec'.format(timeit.timeit(
    'X.draw(1, 4840)', number=1000, globals=globals())))
print('np.random.dirichlet {:3f} sec'.format(timeit.timeit(
    'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))
```
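The class only implements `interp='nearest'`. As a hypothetical sketch of what a linear option might look like (not part of the answer), one could draw continuous uniforms and interpolate the tabulated ppf with `np.interp`. `make_linear_sampler` is an illustrative name; the table is still built once, up front. Uniforms below the first or above the last grid midpoint are clamped to the end values, so the extreme tails remain truncated, much as with the nearest lookup.

```python
import numpy as np
from scipy import special


def make_linear_sampler(alpha, resolution=2**16, seed=None):
    """Hypothetical linear-interpolation variant of symm_dirichlet.draw."""
    # Same midpoint grid and tabulated Gamma(alpha) ppf as the class above.
    grid, delta = np.linspace(0, 1, resolution, endpoint=False, retstep=True)
    grid += delta / 2
    table = special.gammaincinv(alpha, grid)  # built once (slow step)
    rng = np.random.default_rng(seed)

    def draw(n_sampl, n_comp):
        # Continuous uniforms mapped through the piecewise-linear ppf;
        # np.interp clamps values outside [grid[0], grid[-1]].
        u = rng.random((n_sampl, n_comp))
        gamma = np.interp(u, grid, table)
        # Row-wise normalization turns gamma draws into Dirichlet samples.
        return gamma / gamma.sum(axis=1, keepdims=True)

    return draw
```

Usage would mirror the class: `draw = make_linear_sampler(0.03)` once, then `draw(1, 4840)` per MCTS expansion.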
It's worth checking that this is roughly correct: