GaussianNB: - ValueError:先验的总和应为1

时间:2017-08-26 14:24:07

标签: python-2.7 machine-learning scikit-learn gaussian

我想做什么?

我正在尝试使用GaussianNB分类器训练具有10个标签的数据集,但在调整我的gaussianNB先前参数时,我收到此错误: -

文件“/home/mg/anaconda2/lib/python2.7/site-packages/sklearn/naive_bayes.py”,第367行,在_partial_fit中     提高ValueError('先验的总和应为1.') ValueError:先验的总和应为1。

此代码: -     clf = GaussianNB(priors = [0.08,0.14,0.03,0.16,0.11,0.16,0.07,0.14,0.11,0.0])

你可以看到总和显然为1,但它显示了这个错误,你能指出错误。

1 个答案:

答案 0 :(得分:6)

这看起来像是一个非常糟糕的设计决策,因为他们正在做通常的不比较浮点数的东西(what every computer scientist should know about floating-point arithmetic),这让我感到惊讶(作为sklearn通常是高质量的代码)!

(尽管使用了列表,但我没有看到任何错误的用法。文档需要一个数组,而不是像许多其他情况一样的数组,但是他们的代码正在做然而,阵列转换)

Their code

if self.priors is not None:
    priors = np.asarray(self.priors)
    # Check that the provide prior match the number of classes
    if len(priors) != n_classes:
        raise ValueError('Number of priors must match number of'
                         ' classes.')
    # Check that the sum is 1
    if priors.sum() != 1.0:
        raise ValueError('The sum of the priors should be 1.')
    # Check that the prior are non-negative
    if (priors < 0).any():
        raise ValueError('Priors must be non-negative.')
    self.class_prior_ = priors
else:
    # Initialize the priors to zeros for each class
    self.class_prior_ = np.zeros(len(self.classes_),
                                 dtype=np.float64)

所以:

  • 您提供了一个列表,但他们的代码将创建一个numpy-array
  • 因此np.sum()将用于求和
  • 在您的情况下,可能会有 fp-math相关的数字错误求和
    • 你的总和在技术上是!= 1.0;但非常接近它!
  • fp-comparison x == 1.0被认为是错误的!
    • numpy带来np.isclose(),这是通常的做法

演示:

import numpy as np

priors = np.array([0.08, 0.14, 0.03, 0.16, 0.11, 0.16, 0.07, 0.14, 0.11, 0.0])
my_sum = np.sum(priors)
print('my_sum: ', my_sum)
print('naive: ', my_sum == 1.0)
print('safe: ', np.isclose(my_sum, 1.0))

输出:

('my_sum: ', 1.0000000000000002)
('naive: ', False)
('safe: ', True)

修改

由于我认为此代码效果不佳,我发布了一个问题here,您可以查看该问题是否符合要求。

numpy.random.sample(),它也采用了这样一个向量,实际上也在进行fp安全方法(数值上更稳定的求和+ epsilon-check;但不使用np.isclose())如{{3} }。