为什么numpy.random.dirichlet()不接受多维数组?

时间:2013-04-10 01:31:59

标签: python numpy

numpy page上,他们举了

的例子
s = np.random.dirichlet((10, 5, 3), 20)

这一切都很好而且很棒;但是如果你想从二维alpha阵列生成随机样本呢?

alphas = np.random.randint(10, size=(20, 3))

如果你尝试np.random.dirichlet(alphas)np.random.dirichlet([x for x in alphas])或者np.random.dirichlet((x for x in alphas)),它会产生一个 ValueError: object too deep for desired array。似乎唯一有用的是:

y = np.empty(alphas.shape)
for i in xrange(np.alen(alphas)):
    y[i] = np.random.dirichlet(alphas[i])
    print y

...这对我的代码结构来说远非理想。为什么会出现这种情况,任何人都可以想到一种更像“笨拙”的做法吗?

提前致谢。

2 个答案:

答案 0 :(得分:4)

编写

np.random.dirichlet以生成单个Dirichlet分布的样本。该代码是根据Gamma分布实现的,并且该实现可以用作矢量化代码的基础,以从不同分布生成样本。在下文中,dirichlet_sample采用具有形状(n,k)的数组alphas,其中每行是Dirichlet分布的alpha向量。它返回一个也带有形状(n,k)的数组,每一行都是来自alphas的相应分布的样本。当作为脚本运行时,它使用dirichlet_samplenp.random.dirichlet生成样本以验证它们是否生成相同的样本(达到正常的浮点差异)。

import numpy as np


def dirichlet_sample(alphas):
    """
    Generate samples from an array of alpha distributions.

    `alphas` must be a numpy array with shape (n, k).
    """
    r = np.random.standard_gamma(alphas)
    r /= r.sum(-1).reshape(-1, 1)
    return r


if __name__ == "__main__":
    alphas = 2 ** np.random.randint(0, 4, size=(6, 3))

    np.random.seed(1234)
    d1 = dirichlet_sample(alphas)
    print "dirichlet_sample:"
    print d1

    np.random.seed(1234)
    d2 = np.empty(alphas.shape)
    for k in range(len(alphas)):
        d2[k] = np.random.dirichlet(alphas[k])
    print "np.random.dirichlet:"
    print d2

    # Compare d1 and d2:
    err = np.abs(d1 - d2).max()
    print "max difference:", err

示例运行:

dirichlet_sample:
[[ 0.38980834  0.4043844   0.20580726]
 [ 0.14076375  0.26906604  0.59017021]
 [ 0.64223074  0.26099934  0.09676991]
 [ 0.21880145  0.33775249  0.44344606]
 [ 0.39879859  0.40984454  0.19135688]
 [ 0.73976425  0.21467288  0.04556287]]
np.random.dirichlet:
[[ 0.38980834  0.4043844   0.20580726]
 [ 0.14076375  0.26906604  0.59017021]
 [ 0.64223074  0.26099934  0.09676991]
 [ 0.21880145  0.33775249  0.44344606]
 [ 0.39879859  0.40984454  0.19135688]
 [ 0.73976425  0.21467288  0.04556287]]
max difference: 5.55111512313e-17

答案 1 :(得分:2)

我认为你正在寻找

y = np.array([np.random.dirichlet(x) for x in alphas])

列表理解。否则你只是传递一个python列表或元组。我想象numpy.random.dirichlet不接受你的alpha值列表的原因是因为它没有被设置为 - 它已经接受了一个数组,根据文档,它预计具有k的维度。