numpy.random.multinomial坏输出?

时间:2014-04-24 00:48:16

标签: python numpy random numerical-stability

我有这个功能:

import numpy as np 
def unhot(vec):
    """ takes a one-hot vector and returns the corresponding integer """
    assert np.sum(vec) == 1    # this assertion shouldn't fail, but it did...
    return list(vec).index(1)

我打电话给输出:

numpy.random.multinomial(1, coe)

当我运行它时,我得到了一个断言错误。这怎么可能? numpy.random.multinomial的输出是否保证是单热矢量?

然后我删除了断言错误,现在我已经:

ValueError: 1 is not in list

我缺少一些精美的印刷品,或者这只是破损了吗?

1 个答案:

答案 0 :(得分:1)

嗯,这就是问题,我应该意识到,因为我之前遇到过它:

np.random.multinomial(1,A([  0.,   0.,  np.nan,   0.]))

返回

array([0,                    0, -9223372036854775807,0])

我正在使用一个不稳定的softmax实现给Nans。 现在,我试图确保我通过多项式的参数总和< = 1,但我这样做了:

coe = softmax(coeffs)
while np.sum(coe) > 1-1e-9:
    coe /= (1+1e-5)

并且在那里使用NaNs,我认为while语句永远不会被触发。