Loss decreases but F1 score stays the same

Time: 2019-07-18 09:22:47

Tags: pytorch ner crf

The model's loss decreases, but its performance (e.g. the F1 score) does not improve.

I want to fine-tune XLM, the cross-lingual language model from Facebook, for an NER task, so I stacked a BiLSTM and a CRF on top of it.
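The post does not say which CRF implementation is used, but a CRF layer on top of a BiLSTM is normally a linear-chain CRF. As a hedged sketch of what such a layer computes: the training loss is the negative log-likelihood, log Z minus the score of the gold tag path, where log Z comes from the forward algorithm, and decoding runs Viterbi over the same emission and transition scores. This pure-Python version omits the start/end transition scores and the padding mask that libraries such as pytorch-crf also support. `path_score`, `log_partition`, `neg_log_likelihood` and `viterbi_decode` are illustrative names, not the API of the CRF in the repository.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]  # emissions[t][j]: tag j at step t; transitions[i][j]: tag i -> tag j


def logsumexp(values: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(v)))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(emissions: Matrix, transitions: Matrix, tags: List[int]) -> float:
    # Unnormalized score of one tag sequence: emissions plus transitions along the path
    score = emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score


def log_partition(emissions: Matrix, transitions: Matrix) -> float:
    # Forward algorithm: log of the summed exp-scores of all tag sequences
    num_tags = len(emissions[0])
    alpha = list(emissions[0])
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(num_tags)]) + emissions[t][j]
            for j in range(num_tags)
        ]
    return logsumexp(alpha)


def neg_log_likelihood(emissions: Matrix, transitions: Matrix, tags: List[int]) -> float:
    # CRF training loss for one sentence: log Z - score(gold path), always >= 0
    return log_partition(emissions, transitions) - path_score(emissions, transitions, tags)


def viterbi_decode(emissions: Matrix, transitions: Matrix) -> List[int]:
    # Highest-scoring tag sequence, recovered through back-pointers
    num_tags = len(emissions[0])
    score = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_score, bp = [], []
        for j in range(num_tags):
            best_i = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            new_score.append(score[best_i] + transitions[best_i][j] + emissions[t][j])
            bp.append(best_i)
        score = new_score
        backpointers.append(bp)
    path = [max(range(num_tags), key=lambda j: score[j])]
    for bp in reversed(backpointers):
        path.append(bp[path[-1]])
    path.reverse()
    return path
```

On a toy example one can check both routines against brute-force enumeration of every tag sequence: `log_partition` should equal the log-sum-exp of all path scores, and `viterbi_decode` should return the path with the largest score.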

This is my model architecture. The whole code repository has been uploaded to GitHub: https://github.com/stefensa/XLM_NER

```python
import torch.nn as nn

# TransformerModel (XLM's encoder) and CRF come from the linked repository;
# their import lines are not shown in the post.


class XLM_BiLSTM_CRF(nn.Module):
    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.batch_size = config.batch_size
        self.hidden_dim = config.hidden_dim

        # Build XLM and load the pretrained weights
        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        # BiLSTM over XLM's hidden states (time-major, since batch_first=False)
        self.lstm = nn.LSTM(config.embedding_dim, config.hidden_dim // 2,
                            num_layers=1, bidirectional=True)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_dim, config.num_class)
        # Runs init_bert_weights on every submodule, recursively (self.xlm included)
        self.apply(self.init_bert_weights)
        self.crf = CRF(config.num_class)

    def forward(self, word_ids, lengths, langs=None, causal=False):
        # XLM's 'fwd' mode returns hidden states of shape (slen, bs, dim)
        sequence_output = self.xlm('fwd', x=word_ids, lengths=lengths, causal=False).contiguous()
        sequence_output, _ = self.lstm(sequence_output)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        # Best tag sequence for each sentence
        return self.crf.decode(logits)

    def log_likelihood(self, word_ids, lengths, tags):
        sequence_output = self.xlm('fwd', x=word_ids, lengths=lengths, causal=False).contiguous()
        sequence_output, _ = self.lstm(sequence_output)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        # Negated CRF log-likelihood of the gold tags
        return - self.crf(logits, tags.transpose(0,1))

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
```
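`nn.Module.apply(fn)` calls `fn` on every submodule recursively (children first, then the module itself). In the constructor above, `self.apply(self.init_bert_weights)` runs after `self.xlm.load_state_dict(...)`, so it also visits the `nn.Linear` and `nn.Embedding` layers inside `self.xlm`. The sketch below reproduces that ordering with a tiny stand-in; `Pretrained`, `Wrapper`, `init_weights` and the `init_everything` flag are hypothetical names, not code from the repository. It compares applying the initializer to the whole model with applying it only to the newly added layer.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):
    """Hypothetical stand-in for a pretrained encoder such as XLM."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


def init_weights(module):
    # Same idea as init_bert_weights: re-draw Linear weights and zero the biases
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()


class Wrapper(nn.Module):
    def __init__(self, pretrained_state, init_everything):
        super().__init__()
        self.encoder = Pretrained()
        self.encoder.load_state_dict(pretrained_state)  # load first, as in the post
        self.head = nn.Linear(4, 2)
        if init_everything:
            self.apply(init_weights)       # also visits self.encoder.proj
        else:
            self.head.apply(init_weights)  # only the newly added layer


torch.manual_seed(0)
# "Pretrained" weights, cloned so they are independent of any module
state = {k: v.clone() for k, v in Pretrained().state_dict().items()}
whole = Wrapper(state, init_everything=True)
head_only = Wrapper(state, init_everything=False)
print(torch.equal(whole.encoder.proj.weight, state["proj.weight"]))      # False
print(torch.equal(head_only.encoder.proj.weight, state["proj.weight"]))  # True
```

In the first case the loaded encoder weights have been replaced by fresh random values; in the second they are exactly the ones passed to `load_state_dict`.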

This is the initial state of my model:

[screenshot]

This is my model at epoch 9. The metrics have not changed:

[screenshot]
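The post does not say how F1 is computed. For NER it is commonly computed at the entity level (as in the CoNLL evaluation script and seqeval): a predicted entity counts only if both its type and its exact span match a gold entity. A minimal pure-Python sketch, assuming BIO tags; `bio_spans` and `span_f1` are hypothetical helpers, not code from the repository.

```python
from typing import List, Set, Tuple

Span = Tuple[str, int, int]  # (entity type, start index, end index exclusive)


def bio_spans(tags: List[str]) -> Set[Span]:
    # Collect entity spans from one BIO-tagged sentence
    spans: Set[Span] = set()
    label, start = None, 0
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
        continues = tag.startswith("I-") and tag[2:] == label
        if not continues:
            if label is not None:
                spans.add((label, start, i))
                label = None
            if tag.startswith("B-") or tag.startswith("I-"):
                # A stray I- tag is treated as the start of a new entity
                label, start = tag[2:], i
    return spans


def span_f1(gold: List[List[str]], pred: List[List[str]]) -> float:
    # Micro-averaged entity-level F1 over a list of sentences
    tp = n_gold = n_pred = 0
    for g, p in zip(gold, pred):
        gs, ps = bio_spans(g), bio_spans(p)
        tp += len(gs & ps)
        n_gold += len(gs)
        n_pred += len(ps)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


gold = [["B-PER", "I-PER", "O", "B-LOC"]]
print(span_f1(gold, [["O", "O", "O", "O"]]))  # 0.0
print(span_f1(gold, gold))                    # 1.0
```

With this metric, decoded sequences that contain only `O` score exactly 0 however the loss moves, because no entities are predicted; F1 only changes once the decoded tags start to contain entity spans.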

Can anyone help me solve this? It has been puzzling me for two days.

0 answers