def sample(preds, temperature=1.0):
preds = np.asarray(preds).astype('float64')
preds = np.log(preds) / temperature
exp_preds = np.exp(preds)
preds = exp_preds / np.sum(exp_preds)
probas = np.random.multinomial(1, preds, 1)
return np.argmax(probas)
这是为了简化上述代码:
def sample(p, temperature=1.0):
p = np.exp(np.log(p) / temperature)
p = np.random.multinomial(1, p / p.sum(), 1)
return np.argmax(p)
然而第二个失败了这个错误:
File "z.py", line 75, in sample
p = np.random.multinomial(1, p / p.sum(), 1)
File "mtrand.pyx", line 4593, in mtrand.RandomState.multinomial (numpy/random/mtrand/mtrand.c:37541)
ValueError: sum(pvals[:-1]) > 1.0
这怎么可能?
答案 0 :(得分:-1)
我的系统是64位的,所以numpy的默认dtype是float64,但是进入这个函数的一些输入是32位浮点数,所以在那里混合两种数据类型会导致错误。