NLTK confusion matrix

Time: 2016-03-28 02:27:51

Tags: python machine-learning nltk

I am trying to produce a confusion matrix with NLTK. I tried the following example, and it works fine.

>>> import nltk
>>> from nltk.metrics import *
>>> from nltk.corpus import brown
>>> brown_a = nltk.corpus.brown.tagged_sents()[:300]
>>> def tag_list(tagged_sents):
    return [tag for sent in tagged_sents for (word, tag) in sent]

>>> tagger = nltk.UnigramTagger(brown_a)
>>> gold = tag_list(brown_a)
>>> def apply_tagger(tagger, corpus):
    return [tagger.tag(nltk.tag.untag(sent)) for sent in corpus]
>>> test = tag_list(apply_tagger(tagger, brown_a))
>>> cm = nltk.ConfusionMatrix(gold, test)
>>> print(cm.pretty_format(show_percents=False, values_in_chart=True, truncate=5, sort_by_count=True))

But if I supply a separate test set, as follows,

>>> tests=nltk.corpus.brown.tagged_sents()[300:400]
>>> test = tag_list(apply_tagger(tagger, tests))
>>> cm = nltk.ConfusionMatrix(gold, test)

it raises an error:

Traceback (most recent call last):
  File "<pyshell#12>", line 1, in <module>
    cm = nltk.ConfusionMatrix(gold, test)
  File "C:\Python27\lib\site-packages\nltk\metrics\confusionmatrix.py", line 46, in __init__
    raise ValueError('Lists must have the same length.')
ValueError: Lists must have the same length.

Even when I use a test set with the same number of sentences (300),

>>> test1=nltk.corpus.brown.tagged_sents()[700:1000]
>>> test = tag_list(apply_tagger(tagger, test1))
>>> cm = nltk.ConfusionMatrix(gold, test)

I get the same error:

Traceback (most recent call last):
  File "<pyshell#23>", line 1, in <module>
    cm = nltk.ConfusionMatrix(gold, test)
  File "C:\Python27\lib\site-packages\nltk\metrics\confusionmatrix.py", line 46, in __init__
    raise ValueError('Lists must have the same length.')
ValueError: Lists must have the same length.

Can anyone help me fix this?

2 Answers:

Answer 0 (score: 0)

Look at the source of ConfusionMatrix:

def __init__(self, reference, test, sort_by_count=False):
    """
    Construct a new confusion matrix from a list of reference
    values and a corresponding list of test values.

    :type reference: list
    :param reference: An ordered list of reference values.
    :type test: list
    :param test: A list of values to compare against the
        corresponding reference values.
    :raise ValueError: If ``reference`` and ``length`` do not have
        the same length.
    """
    if len(reference) != len(test):
        raise ValueError('Lists must have the same length.')

http://www.nltk.org/_modules/nltk/metrics/confusionmatrix.html
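To make that check concrete, here is a minimal pure-Python sketch (not NLTK's implementation) of what a confusion matrix tabulates. It counts (reference, test) pairs position by position, which is why the two lists must line up one-to-one. The function name `confusion_counts` and the tag lists are hypothetical illustrations.

```python
from collections import Counter

def confusion_counts(reference, test):
    """Count (reference, test) label pairs position by position.

    Hypothetical, simplified stand-in for what nltk.ConfusionMatrix
    tabulates; it is NOT NLTK's own code.
    """
    # Same guard as the NLTK source quoted above: the i-th test value is
    # only meaningful when compared with the i-th reference value.
    if len(reference) != len(test):
        raise ValueError('Lists must have the same length.')
    return Counter(zip(reference, test))

gold = ['DT', 'NN', 'VB', 'NN']
pred = ['DT', 'NN', 'NN', 'NN']
counts = confusion_counts(gold, pred)
print(counts[('VB', 'NN')])  # 1: one VB token was tagged NN
```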

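I'm not going to go through your code, since it has been a while since I used NLTK, but try printing your gold-standard and prediction lists and make sure they have the same length.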
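Below is a sketch of that length check. It also shows the usual way to make the two lists line up: take the reference tags from the same sentences that get tagged. To run without NLTK data, it substitutes toy sentences for the brown corpus and a hypothetical dictionary lookup, `toy_tag`, for `nltk.UnigramTagger`.

```python
def tag_list(tagged_sents):
    # Flatten tagged sentences into one list of tags, one entry per token.
    return [tag for sent in tagged_sents for (word, tag) in sent]

def toy_tag(words, lookup):
    # Hypothetical stand-in for a unigram tagger: look each word up,
    # yielding None for words not seen in training.
    return [(w, lookup.get(w)) for w in words]

# Hypothetical toy data in the [(word, tag), ...] shape of brown.tagged_sents().
train_sents = [[('the', 'AT'), ('dog', 'NN'), ('ran', 'VBD')]]
test_sents = [[('the', 'AT'), ('cat', 'NN')],
              [('a', 'AT'), ('dog', 'NN'), ('barked', 'VBD')]]
lookup = {w: t for sent in train_sents for (w, t) in sent}

# Reference tags come from the SAME sentences that are tagged,
# so position i in both lists refers to the same token.
gold = tag_list(test_sents)
test = tag_list([toy_tag([w for (w, _) in sent], lookup) for sent in test_sents])

print(len(gold), len(test))  # 5 5
assert len(gold) == len(test), 'gold and test must be aligned token by token'
```

With the question's own helpers, this corresponds to `gold = tag_list(tests)` paired with `test = tag_list(apply_tagger(tagger, tests))`. Those two lists should have equal lengths, because the tagger emits one tag per token.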

Answer 1 (score: 0)

Both examples that raise the error have a length mismatch:

  • Example 1: len(test) = 2459 and len(gold) = 6642
  • Example 2: len(test) = 6261 and len(gold) = 6642
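In Example 2, both slices hold 300 sentences, yet the lengths still differ. The reason is that `tag_list` produces one entry per token, not one per sentence, and sentences vary in length. Here is a toy illustration with hypothetical data:

```python
def tag_list(tagged_sents):
    # Flatten sentences into one list of tags, one entry per token.
    return [tag for sent in tagged_sents for (word, tag) in sent]

# Hypothetical toy "slices" with the same number of sentences (2 each).
slice_a = [[('I', 'PPSS'), ('ran', 'VBD')],
           [('Go', 'VB')]]
slice_b = [[('The', 'AT'), ('dog', 'NN'), ('ran', 'VBD')],
           [('It', 'PPS'), ('barked', 'VBD'), ('.', '.')]]

print(len(slice_a), len(slice_b))                      # 2 2
print(len(tag_list(slice_a)), len(tag_list(slice_b)))  # 3 6
```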

You can trim gold like this:

gold_full = tag_list(brown_a)
gold = gold_full[:len(test)]

This assumes the gold standard is longer than test; otherwise, you can add a condition.
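Here is a minimal sketch of such a condition, using a hypothetical helper, `trim_to_common_length`, that cuts both lists to the shorter length. It only satisfies the length check, though. When gold and test come from different sentences, position i in each list refers to a different token, so the resulting counts do not measure tagging accuracy. Comparing tags of the same sentences avoids that problem.

```python
def trim_to_common_length(gold_full, test):
    """Hypothetical helper: cut both lists to the shorter length so that
    the ConfusionMatrix length check passes whichever list is longer."""
    n = min(len(gold_full), len(test))
    return gold_full[:n], test[:n]

gold, test = trim_to_common_length(['AT', 'NN', 'VBD', 'AT'], ['AT', 'NN'])
print(len(gold), len(test))  # 2 2
```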