Error when using TensorFlow's viterbi_decode

Posted: 2018-06-01 03:32:04

Tags: tensorflow ner

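I am using github.com/Determined22/zh-NER-TF. The only change I made was to swap in a different train_data file in the same format. The code itself seems fine, because it runs without problems on the original train_data. What could be causing this error?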

Traceback (most recent call last):
  File "main.py", line 83, in <module>
    model.train(train=train_data, dev=dev_data)
  File "/home/mengyuguang/NER/model.py", line 161, in train
    self.run_one_epoch(sess, train, dev, self.tag2label, epoch, saver)
  File "/home/mengyuguang/NER/model.py", line 221, in run_one_epoch
    label_list_dev, seq_len_list_dev = self.dev_one_epoch(sess, dev)
  File "/home/mengyuguang/NER/model.py", line 256, in dev_one_epoch
    label_list_, seq_len_list_ = self.predict_one_batch(sess, seqs)
  File "/home/mengyuguang/NER/model.py", line 277, in predict_one_batch
    viterbi_seq, _ = viterbi_decode(logit[:seq_len], transition_params)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 333, in viterbi_decode
    trellis[0] = score[0]
IndexError: index 0 is out of bounds for axis 0 with size 0
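The last frame of the traceback shows the cause: viterbi_decode works on NumPy arrays, and its first step copies score[0]. When a sample has length 0, logit[:seq_len] is an empty array with no row 0, so NumPy raises exactly this IndexError. Because the traceback passes through dev_one_epoch, the empty sample in this run comes from the dev data. Below is a minimal NumPy re-implementation, modeled on the decoder in crf.py and written only for illustration. It assumes score has shape [seq_len, num_tags] and transition_params has shape [num_tags, num_tags], and it reproduces the same error for an empty sequence.

```python
import numpy as np


def viterbi_decode(score, transition_params):
    """Illustrative NumPy Viterbi decoder (not the TensorFlow source itself).

    score: [seq_len, num_tags] unary potentials for one sequence.
    transition_params: [num_tags, num_tags] transition potentials.
    Returns the best tag sequence and its score.
    """
    trellis = np.zeros_like(score)
    backpointers = np.zeros_like(score, dtype=np.int32)
    # This is the step that fails when seq_len == 0: score has no row 0.
    trellis[0] = score[0]
    for t in range(1, score.shape[0]):
        # v[i, j]: best score ending in tag i at t-1, then moving to tag j.
        v = np.expand_dims(trellis[t - 1], 1) + transition_params
        trellis[t] = score[t] + np.max(v, 0)
        backpointers[t] = np.argmax(v, 0)
    # Follow the back-pointers from the best final tag.
    viterbi = [int(np.argmax(trellis[-1]))]
    for bp in reversed(backpointers[1:]):
        viterbi.append(int(bp[viterbi[-1]]))
    viterbi.reverse()
    return viterbi, float(np.max(trellis[-1]))


num_tags = 3
transitions = np.zeros((num_tags, num_tags))
logits = np.array([[1.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0]])

print(viterbi_decode(logits[:2], transitions))  # prints ([0, 1], 3.0)

try:
    viterbi_decode(logits[:0], transitions)  # a sample with seq_len == 0
except IndexError as e:
    print("IndexError:", e)  # index 0 is out of bounds for axis 0 with size 0
```

Any fix therefore has to ensure that no sample with zero characters reaches the decoder. The answers below do this by changing how the corpus is read.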

2 answers:

Answer 0 (score: 0)

def read_corpus(self, corpus_path):
    data = []
    with open(corpus_path, 'r') as r_file:
        sent_, tag_ = [], []
        for line in r_file:
            line = line.strip()
            if len(line) != 0 and line != '-DOCSTART-':
                ls = line.split('\t')
                char, tag = ls[0], ls[-1]
                sent_.append(char)
                tag_.append(tag)
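            elif sent_ and tag_:
                # A blank line or '-DOCSTART-' ends the current sentence.
                # Skip it when nothing has been collected; otherwise repeated
                # blank lines add an empty (sent_, tag_) pair whose sequence
                # length of 0 makes viterbi_decode raise IndexError.
                data.append((sent_, tag_))
                sent_, tag_ = [], []
        # Bug fix: append the last sentence when the file does not end with
        # a blank line, but only if it is non-empty, because an empty
        # sequence triggers the same IndexError in viterbi_decode.
        if sent_ and tag_:
            data.append((sent_, tag_))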
    self.data = data
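To check the behavior without a real corpus file, the reading logic can be moved into a standalone function that takes a list of lines. read_corpus_lines is a hypothetical helper written for illustration. It keeps the tab-separated format and the '-DOCSTART-' handling from the answer, guards the blank-line branch so that only non-empty sentences are stored, and still appends a trailing sentence.

```python
def read_corpus_lines(lines):
    """Hypothetical standalone variant of read_corpus above.

    Takes any iterable of text lines instead of a file path and returns a
    list of (chars, tags) tuples that never contains an empty sentence.
    """
    data = []
    sent_, tag_ = [], []
    for line in lines:
        line = line.strip()
        if line and line != '-DOCSTART-':
            # Tab-separated: first column is the character, last is the tag.
            ls = line.split('\t')
            sent_.append(ls[0])
            tag_.append(ls[-1])
        elif sent_ and tag_:
            # Blank line or document marker ends a non-empty sentence.
            data.append((sent_, tag_))
            sent_, tag_ = [], []
    # The last sentence may not be followed by a blank line.
    if sent_ and tag_:
        data.append((sent_, tag_))
    return data


sample = "-DOCSTART-\n\n中\tB-LOC\n国\tI-LOC\n\n\n我\tO\n"
corpus = read_corpus_lines(sample.splitlines())
print(len(corpus))                           # prints 2
print([len(chars) for chars, _ in corpus])   # prints [2, 1]
```

With an unguarded else branch, the same sample would also produce three empty samples: one for '-DOCSTART-', one for the blank line after it, and one for the second of the two consecutive blank lines. Each of those would reach viterbi_decode with seq_len 0.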

Answer 1 (score: 0)

The read_corpus function should be changed as follows, so that blank lines no longer produce empty samples:

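import io


def read_corpus(corpus_path):
    """
    read corpus and return the list of samples
    :param corpus_path: path to a file with one "char label" pair per line
        and a blank line after each sentence
    :return: data, a list of (chars, labels) tuples, one per sentence
    """
    data = []
    # io.open accepts encoding= on both Python 2.7 and Python 3
    with io.open(corpus_path, encoding='utf-8') as fr:
        lines = fr.readlines()
    sent_, tag_ = [], []
    for line in lines:
        line = line.strip()
        if line:
            # Non-blank line: one character and its label
            char, label = line.split()
            sent_.append(char)
            tag_.append(label)
        elif sent_ and tag_:
            # A blank line ends a sentence. Only store the sentence if it is
            # non-empty; consecutive blank lines (or lines containing only
            # whitespace such as '\t') would otherwise add samples of
            # length 0, which make viterbi_decode raise IndexError.
            data.append((sent_, tag_))
            sent_, tag_ = [], []

    return data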
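Because the original train_data works and the new file does not, it can also help to inspect the loaded samples directly before training. find_bad_samples is a hypothetical helper that assumes data is the list of (chars, labels) tuples returned by read_corpus. It reports empty samples, which reach viterbi_decode with seq_len 0, and samples whose character and label counts differ.

```python
def find_bad_samples(data):
    """Hypothetical sanity check for the output of read_corpus.

    data: list of (chars, labels) tuples.
    Returns (empty, mismatched): indices of samples with no characters,
    and indices whose character and label counts differ.
    """
    empty, mismatched = [], []
    for i, (chars, labels) in enumerate(data):
        if len(chars) == 0:
            # Would be decoded with seq_len == 0 and raise IndexError.
            empty.append(i)
        elif len(chars) != len(labels):
            mismatched.append(i)
    return empty, mismatched


data = [(['中', '国'], ['B-LOC', 'I-LOC']), ([], []), (['我'], ['O'])]
empty, mismatched = find_bad_samples(data)
print(empty, mismatched)  # prints [1] []

# Dropping empty samples works as a stopgap; fixing read_corpus is cleaner.
clean = [sample for sample in data if sample[0]]
print(len(clean))  # prints 2
```

Running the check on both the training file and the dev file points directly to the samples that break decoding.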