Python分类器Sklearn

时间:2016-04-14 21:08:11

标签: python scikit-learn

我是Python和SKLearn的新手。我正在尝试制作一个简单的分类器,但我遇到了一个问题。我一直在关注一些不同的教程,但在尝试使用.fit方法时遇到错误。我是这个概念的新手并且已经尝试了文档,但发现很难理解,任何人都可以帮我解决错误或指出我正确的方向。

我在错误背后的想法是,这个值超出了dtype的范围,因为我已经转换了所有缺失的值或nan值,但错误仍在出现

代码

def main():
setup_files()

imputer = Imputer()

#the training data minus id and type:
t_num_data = load_csv(training_set_file_path, range(1, 17))
t_num_data_imputed = imputer.fit_transform(t_num_data)
print(t_num_data_imputed)

#the training type column
t_type_col = load_csv(training_set_file_path, 17, dtype=np.dtype((str, 5)))
#the query data minus id and type:
q_data = load_csv(queries_file_path, range(1, 17))
#the query id column
q_id = load_csv(queries_file_path, 0, dtype=np.dtype((str, 10)))


#fit data above to DTC and predict import
model = tree.DecisionTreeClassifier(criterion='entropy')
model.fit_transform(t_num_data, t_type_col)
predictions = model.predict(q_data)


#output the predictions:
with open(solutions_file_path, 'w') as f:
    for i in range(len(predictions)):
        f.write("{},{}\n".format(q_id[i], predictions[i]))


#fit data above to DTC and predict import
model = tree.DecisionTreeClassifier(criterion='entropy')
model.fit(t_num_data, t_type_col)
predictions = model.predict(q_data)


#output the predictions:
with open(solutions_file_path, 'w') as f:
    for i in range(len(predictions)):
        f.write("{},{}\n".format(q_id[i], predictions[i]))

错误

Traceback (most recent call last):
  File "/Users/Rory/Desktop/classifier.py", line 71, in <module>
main()
  File "/Users/Rory/Desktop/classifier.py", line 60, in main
model.fit_transform(t_num_data, t_type_col)
  File "/Users/Rory/anaconda/lib/python2.7/site-packages/sklearn/base.py", line 458, in fit_transform
return self.fit(X, y, **fit_params).transform(X)
  File "/Users/Rory/anaconda/lib/python2.7/site-packages/sklearn/tree/tree.py", line 154, in fit
    X = check_array(X, dtype=DTYPE, accept_sparse="csc")
  File "/Users/Rory/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.py", line 398, in check_array
_assert_all_finite(array)
  File "/Users/Rory/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.py", line 54, in _assert_all_finite
" or a value too large for %r." % X.dtype)
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

1 个答案:

答案 0 :(得分:1)

问题是你的NaN值。有很多方法可以估算NaNs。你可以尝试:

t_num_data.fillna(0)

将使用0填充所有缺失值,然后您的分类器将起作用,但可能不是很准确。还有其他方法采用平均值,基于最近邻居的估计等。但这应该让你的代码现在正常工作。