SGDClassifier从稀疏数据集生成Keyerror

时间:2013-08-30 14:40:28

标签: python pandas scikit-learn

我正在进行一些文本分析,使用Pandas提取数据。

X = pd.read_Csv('../data/training.tsv', sep ='\t', na_values=['?'])
X['json'] = X['json'].apply(json.loads)
extractBody = lambda x: x['body'] if x.has_key('body') and x['body'] is not None else u'empty'
X_all['body'] = X['json'].map(extractBody)

我把它扔进一个scikit-learn向量,将tf-idf加权步骤分开:

body_counter = CountVectorizer()
body_counts = body_counter.fit_transform(X_all['body'])
body_transform = TfidfTransformer()
body_counts = body_tranform.fit_transform(body_counts)

我想使用SGDClassifier来预测某种意义上的“垃圾邮件”/“非垃圾邮件”的简单二进制分类。

model - SGDClassifier(n_iter = 5, loss = log)
model.fit(body_counts, labels)

运行时,fit方法会生成以下KeyError:

...
return self.index.get_value(self,key)
...
return self._engine.get)value(series, key)

File "index.pyx", line 96, in pandas.index.INdexEngine.get_value (pandas/index.c:2873)
File "index.pyx", line 104, in pandas.index.IndexEngine.get_value (pandas/index.c:2685)
File "index.pyx", line 148, in pandas.index.IndexEngine.get_loc (pandas/index.c:3422)
File "hashtable.pyx", line 382, in pandas.hashtable.Int64HashTable.get_item (pandas/hashtable.c:6570)
File "hashtable.pyx", line 388, in pandas.hashtable.Int64HashTable.get_item (pandas/hashtable.c:6511)
KeyError: 0

我不确定这里发生了什么。当我只想交叉验证它时,这个模型工作正常(cross_val_score)。我可以在scikit中使用naive_bayes或TruncatedSVD运行此数据集。这只有在我尝试适合这个模型时才会发生,我不知道为什么。

我该如何解决这个问题?或者我正在查看scikit中的错误学习?

修改

是的,不幸的是我不得不将我的代码重写为这篇文章而不是复制,因此可能存在一些错误。我在没有wifi连接的笔记本电脑上编码。

X.shape = 7396, 105273
labels.len() = 7395
labels type = 'pandas.core.series.Series'

...我将标签转换为numpy数组,然后就完成了!

仍然让我感到困惑的是cross_val_score会按原样接受标签,但是model.fit不会。

谢谢!

1 个答案:

答案 0 :(得分:3)

我遇到了同样的问题。供将来参考:只需拨打labels.values,即可将pandas系列标签转换为numpy数组。