SciKit-learn的'预测'功能以错误的格式输出

时间:2015-08-10 15:41:31

标签: python machine-learning scikit-learn text-classification

我是scikit的新手,所以玩弄它。

有关问题的背景: 我正试图在hackerRank上玩'Byte正确的苹果'比赛。 其中我们给出了两个文件,一个包含苹果公司的文本和一个用于苹果的文本。现在我们必须从中学习,然后对新文本进行预测。

虽然代码运行但我的问题是: - 由于'line'(在下面的代码中)是单个输入,我应该得到零或一个单位输出。但我得到一个数组作为输出。 - 我甚至接近使用下面的代码学习任何东西?

{{1}}

2 个答案:

答案 0 :(得分:2)

predict函数返回documentation中所述的数组对象。此数组对象对应于labels数组中的索引。要获得line的预测,您需要尝试以下内容:

print labels[predicted]

答案 1 :(得分:2)

我自己找到了答案。

有关

predicted = text_clf.predict(line);

'线'应该是一个列表,而不是一个字符串,因为它适用于' fit'功能

即。取代

line = 'I am talking about the product apple computer by Steve Jobs'

通过

line = [];    
line.append('I am talking about apple the fruit we eat.');

或@jme建议我们可以使用

text_clf.predict([line])