Sci-Kit学习SGD Algo时出错 - “数组包含NaN或无穷大”

时间:2013-09-03 18:20:45

标签: python numpy machine-learning scikit-learn nan

我收到一条错误,指出“数组包含NaN或无穷大”。我检查了我的数据,无论是火车/测试缺失值,都没有遗漏。

我可能对“数组包含NaN或无穷大”的含义有错误的解释。

import numpy as np
from sklearn import linear_model
from numpy import genfromtxt, savetxt

def main():
    #create the training & test sets, skipping the header row with [1:]
    dataset = genfromtxt(open('C:\\Users\\Owner\\training.csv','r'), delimiter=',')[0:50]    
    target = [x[0] for x in dataset]
    train = [x[1:50] for x in dataset]
    test = genfromtxt(open('C:\\Users\\Owner\\test.csv','r'), delimiter=',')[0:50]

    #create and train the SGD
    sgd = linear_model.SGDClassifier()
    sgd.fit(train, target)
    predictions = [x[1] for x in sgd.predict(test)]

    savetxt('C:\\Users\\Owner\\Desktop\\preds.csv', predictions, delimiter=',', fmt='%f')

if __name__=="__main__":
    main()

我认为数据类型可能会抛出一个循环算法(它们是浮点数)。

我知道SGD可以处理浮点数,所以我不确定这个设置是否要求我声明数据类型。

如下列之一:

>>> dt = np.dtype('i4')   # 32-bit signed integer
>>> dt = np.dtype('f8')   # 64-bit floating-point number
>>> dt = np.dtype('c16')  # 128-bit complex floating-point number
>>> dt = np.dtype('a25')  # 25-character string

以下是完整的错误消息:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-62-af5537e7802b> in <module>()
     19
     20 if __name__=="__main__":
---> 21     main()

<ipython-input-62-af5537e7802b> in main()
     13     #create and train the SGD
     14     sgd = linear_model.SGDClassifier()
---> 15     sgd.fit(train, target)
     16     predictions = [x[1] for x in sgd.predict(test)]
     17

C:\Anaconda\lib\site-packages\sklearn\linear_model\stochastic_gradient.pyc in fi
t(self, X, y, coef_init, intercept_init, class_weight, sample_weight)
    518                          coef_init=coef_init, intercept_init=intercept_i
nit,
    519                          class_weight=class_weight,
--> 520                          sample_weight=sample_weight)
    521
    522

C:\Anaconda\lib\site-packages\sklearn\linear_model\stochastic_gradient.pyc in _f
it(self, X, y, alpha, C, loss, learning_rate, coef_init, intercept_init, class_w
eight, sample_weight)
    397             self.class_weight = class_weight
    398
--> 399         X = atleast2d_or_csr(X, dtype=np.float64, order="C")
    400         n_samples, n_features = X.shape
    401

C:\Anaconda\lib\site-packages\sklearn\utils\validation.pyc in atleast2d_or_csr(X
, dtype, order, copy)
    114     """
    115     return _atleast2d_or_sparse(X, dtype, order, copy, sparse.csr_matrix
,
--> 116                                 "tocsr")
    117
    118

C:\Anaconda\lib\site-packages\sklearn\utils\validation.pyc in _atleast2d_or_spar
se(X, dtype, order, copy, sparse_class, convmethod)
     94         _assert_all_finite(X.data)
     95     else:
---> 96         X = array2d(X, dtype=dtype, order=order, copy=copy)
     97         _assert_all_finite(X)
     98     return X

C:\Anaconda\lib\site-packages\sklearn\utils\validation.pyc in array2d(X, dtype,
order, copy)
     79                         'is required. Use X.toarray() to convert to dens
e.')
     80     X_2d = np.asarray(np.atleast_2d(X), dtype=dtype, order=order)
---> 81     _assert_all_finite(X_2d)
     82     if X is X_2d and copy:
     83         X_2d = safe_copy(X_2d)

C:\Anaconda\lib\site-packages\sklearn\utils\validation.pyc in _assert_all_finite
(X)
     16     if (X.dtype.char in np.typecodes['AllFloat'] and not np.isfinite(X.s
um())
     17             and not np.isfinite(X).all()):
---> 18         raise ValueError("Array contains NaN or infinity.")
     19
     20

ValueError: Array contains NaN or infinity.

任何想法都会受到赞赏。

1 个答案:

答案 0 :(得分:0)

如错误报告,您的数据中某处有np.nannp.inf-np.inf。由于您正在从文本文件中读取并且您说数据不包含缺失值,因此可能是由列标题或文件中无法自动解释的其他值引起的。

genfromtxt的文档显示,读入数组的默认dtypefloat,这意味着您读取的所有值都必须等同于float(x)

如果您不确定这是否导致错误,您可以从numpy数组中删除非限定数字,如下所示:

dataset[ ~np.isfinite(dataset) ] = 0  # Set non-finite (nan, inf, -inf) to zero

如果这样可以消除错误,您可以确定变量某处中包含无效值。要查找 where ,您可以使用以下内容:

np.where(~np.isfinite(dataset))

这将返回无效值所在的索引列表,例如

>>> import numpy as np

>>> dataset = np.array([[0,1,1],[np.nan,0,0],[1,2,np.inf]])
>>> dataset
array([[  0.,   1.,   1.],
       [ nan,   0.,   0.],
       [  1.,   2.,  inf]])

>>> np.where(~np.isfinite(dataset))
(array([1, 2]), array([0, 2]))
相关问题