如何在其他地方使用bert预训练模型?

时间:2020-07-16 15:33:30

标签: python tensorflow pytorch pre-trained-model bert-language-model

我遵循这门课程https://www.coursera.org/learn/sentiment-analysis-bert的有关建立情绪分析的预训练模型的课程。在整理期间,他们在每个时期使用torch.save(model.state_dict(), f'BERT_ft_epoch{epoch}.model')保存了模型。现在,我想在其他地方使用这些模型之一(显然是最好的模型),例如,用户可以将推文粘贴为输入并获得作者的情感。但是我不知道如何加载模型并进行预测,这是我尝试过的方法:

import torchvision.models as models
import torch

model = models.resnet101(pretrained=False)
model.load_state_dict(torch.load('Models/BERT_ft_epoch15.model'), strict=False)
model_ft.eval()
output = model_ft(input) #input is a tweets list

我收到此错误:TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not list

2 个答案:

答案 0 :(得分:2)

resnet101BERT是两个完全不同的模型。您无法将预训练的BERT模型加载到resnet中。

答案 1 :(得分:2)

如何使用Pytorch定义,初始化,保存和加载模型。

初始化模型。这是通过继承类nn.Module来完成的,请考虑简单的两层模型:

import torch
import torch.nn as nn

class Model(nn.Module)
    def __init__(self, input_size=128, output_size=10):
        super(Model).__init__()
    
        self.layer1 = nn.Sequetial(nn.Linear(input_size, 64), nn.LeakyReLU())
        self.layer2 = nn.Linear(64, output_size)
    
    def forward(self, x):
        y = self.layer2(self.layer1(x))
        return y

首先在__init__()初始化模型的各层,然后在forward()中指定正向传递的操作。您可以在那里发挥创造力,只记得使用pytorch可区分的操作。

您可以通过创建新类的实例来初始化模型:

model = Model() # brand new instance!

训练完模型后,您要保存它:

import torch
model = Model(128, 10) # initialization

torch.save(model.state_dict, 'model.pt') # saving state dict

您不是要在此处保存模型,而是要保存state_dict,这是一个有序词典,其中包含模型的所有权重和偏差以及其他参数。我们保存state_dict而不是直接保存模型的原因可以在文档(https://pytorch.org/tutorials/beginner/saving_loading_models.html)中找到。现在,只需考虑最佳做法即可。

最后,我们得出如何加载模型。您必须先初始化模型,然后从磁盘加载state_dict

model = Model(128, 10) # model initialization
model.load_state_dict('model.pt')
model.eval() # put the model in inference mode

请注意,当我们保存state_dict时,我们可能还会保存优化器和用于反向传播的图形。这对于检查点培训并在以后的阶段恢复培训很有用。

    # in the training loop
    torch.save({"epoch": epoch,
                "model": model.state_dict,
                "optim": optim.state_dict,
                "loss": loss}, f'checkpoint{epoch}.pt')

我希望为您画一幅清晰的图画=)