I am trying to produce a confusion matrix with NLTK. I tried the following example, and it works fine.
>>> import nltk
>>> from nltk.metrics import *
>>> from nltk.corpus import brown
>>> brown_a = nltk.corpus.brown.tagged_sents()[:300]
>>> def tag_list(tagged_sents):
...     return [tag for sent in tagged_sents for (word, tag) in sent]
...
>>> tagger = nltk.UnigramTagger(brown_a)
>>> gold = tag_list(brown_a)
>>> def apply_tagger(tagger, corpus):
...     return [tagger.tag(nltk.tag.untag(sent)) for sent in corpus]
...
>>> test = tag_list(apply_tagger(tagger, brown_a))
>>> cm = nltk.ConfusionMatrix(gold, test)
>>> print(cm.pretty_format(show_percents=False, values_in_chart=True, truncate=5, sort_by_count=True))
But if I supply a different test set, like this,
>>> tests=nltk.corpus.brown.tagged_sents()[300:400]
>>> test = tag_list(apply_tagger(tagger, tests))
>>> cm = nltk.ConfusionMatrix(gold, test)
it raises an error:
Traceback (most recent call last):
  File "<pyshell#12>", line 1, in <module>
    cm = nltk.ConfusionMatrix(gold, test)
  File "C:\Python27\lib\site-packages\nltk\metrics\confusionmatrix.py", line 46, in __init__
    raise ValueError('Lists must have the same length.')
ValueError: Lists must have the same length.
Even when I try a test set of the same length (300 sentences),
>>> test1=nltk.corpus.brown.tagged_sents()[700:1000]
>>> test = tag_list(apply_tagger(tagger, test1))
>>> cm = nltk.ConfusionMatrix(gold, test)
it gives me the same error.
Traceback (most recent call last):
  File "<pyshell#23>", line 1, in <module>
    cm = nltk.ConfusionMatrix(gold, test)
  File "C:\Python27\lib\site-packages\nltk\metrics\confusionmatrix.py", line 46, in __init__
    raise ValueError('Lists must have the same length.')
ValueError: Lists must have the same length.
Can anyone help me resolve this?
Answer 0 (score: 0)
Look at the source of ConfusionMatrix:
def __init__(self, reference, test, sort_by_count=False):
    """
    Construct a new confusion matrix from a list of reference
    values and a corresponding list of test values.

    :type reference: list
    :param reference: An ordered list of reference values.
    :type test: list
    :param test: A list of values to compare against the
        corresponding reference values.
    :raise ValueError: If ``reference`` and ``length`` do not have
        the same length.
    """
    if len(reference) != len(test):
        raise ValueError('Lists must have the same length.')
http://www.nltk.org/_modules/nltk/metrics/confusionmatrix.html
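I'm not going to go through your code, since it has been a while since I used NLTK, but try printing your gold-standard and predicted arrays and make sure they have the same length.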
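As a rough illustration of that check, here is a small sketch that does not need NLTK or the Brown corpus. The two toy sentence lists (`sents_a`, `sents_b`) are made up for this example; `tag_list` is the helper from the question. Both sets contain two sentences, yet the flattened tag lists differ in length, because the lists are counted in tokens, not sentences.

```python
def tag_list(tagged_sents):
    # Flatten a list of tagged sentences into one list of tags.
    return [tag for sent in tagged_sents for (word, tag) in sent]


# Hypothetical toy data: two sentences in each set, different token counts.
sents_a = [[("The", "AT"), ("dog", "NN")],
           [("It", "PPS"), ("barked", "VBD"), ("loudly", "RB")]]
sents_b = [[("A", "AT"), ("cat", "NN"), ("sat", "VBD")],
           [("Run", "VB")]]

gold = tag_list(sents_a)
test = tag_list(sents_b)

# Same number of sentences, different number of tags.
print("sentences: %d vs %d" % (len(sents_a), len(sents_b)))
print("tags:      %d vs %d" % (len(gold), len(test)))
```

Running the same two `len()` calls on the `gold` and `test` lists from the question should show where the mismatch comes from.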
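Answer 1 (score: 0)
Both of your failing examples hit a length mismatch: gold is built from brown_a, while test is built from different sentences, so the two flat tag lists contain different numbers of tokens.
You can trim gold as follows: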
gold_full = tag_list(brown_a)
gold = gold_full[:len(test)]
This assumes the gold standard is longer than the test list; otherwise you can add a condition.
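One way to write that condition is to cut both lists to the shorter length. Keep in mind that trimming only makes the lengths agree: the tags at each position still come from different sentences, so the resulting matrix may not mean much. An alternative, sketched below, is to build gold from the same sentences that are passed to the tagger. Both helper names (`trim_to_common_length`, `aligned_tags`) and the `AlwaysNounTagger` class are hypothetical and only serve this example; the only thing assumed about the tagger is a `tag(words)` method returning `(word, tag)` pairs, as NLTK taggers provide.

```python
def trim_to_common_length(gold, test):
    # Cut both lists to the shorter length, whichever list that is.
    n = min(len(gold), len(test))
    return gold[:n], test[:n]


def aligned_tags(tagger, tagged_sents):
    # Build gold and test tag lists from the SAME sentences,
    # so position i in both lists refers to the same token.
    gold, test = [], []
    for sent in tagged_sents:
        words = [word for (word, tag) in sent]  # same idea as nltk.tag.untag(sent)
        gold.extend(tag for (word, tag) in sent)
        test.extend(tag for (word, tag) in tagger.tag(words))
    return gold, test


class AlwaysNounTagger(object):
    # Hypothetical stand-in for a trained tagger: tags every word as 'NN'.
    def tag(self, words):
        return [(word, "NN") for word in words]


# Trimming: lengths agree afterwards, whichever list was longer.
gold_trimmed, test_trimmed = trim_to_common_length(
    ["AT", "NN", "VBD", "RB", "."], ["AT", "NN", "VBD"])

# Aligning: gold and test come from the same held-out sentences.
held_out = [[("The", "AT"), ("dog", "NN")],
            [("It", "PPS"), ("barked", "VBD")]]
gold_aligned, test_aligned = aligned_tags(AlwaysNounTagger(), held_out)

print("trimmed: %d vs %d" % (len(gold_trimmed), len(test_trimmed)))
print("aligned: %d vs %d" % (len(gold_aligned), len(test_aligned)))
```

With the names from the question, this would amount to calling `aligned_tags(tagger, tests)` and passing the two returned lists to `nltk.ConfusionMatrix`.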