我正在使用fast-ai库来训练IMDB评论数据集的样本。我的目标是实现情感分析,我只想从一个小的数据集开始(这个数据集包含1000个IMDB评论)。我已经使用this tutorial在VM中训练了模型。
我保存了data_lm
和data_clas
模型,然后保存了编码器ft_enc
,然后保存了分类器学习者sentiment_model
。然后,我从VM中获得了这4个文件,并将它们放入我的计算机中,并希望使用这些经过预训练的模型来对情绪进行分类。
这就是我所做的:
# Use the IMDB_SAMPLE file
path = untar_data(URLs.IMDB_SAMPLE)
# Language model data
data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')
# Sentiment classifier model data
data_clas = TextClasDataBunch.from_csv(path, 'texts.csv',
vocab=data_lm.train_ds.vocab, bs=32)
# Build a classifier using the tuned encoder (tuned in the VM)
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
learn.load_encoder('ft_enc')
# Load the trained model
learn.load('sentiment_model')
之后,我想使用该模型来预测句子的情感。执行此代码时,我遇到了以下错误:
RuntimeError: Error(s) in loading state_dict for AWD_LSTM:
size mismatch for encoder.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]).
size mismatch for encoder_dp.emb.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]).
回溯为:
Traceback (most recent call last):
File "C:/Users/user/PycharmProjects/SentAn/mainApp.py", line 51, in <module>
learn = load_models()
File "C:/Users/user/PycharmProjects/SentAn/mainApp.py", line 32, in load_models
learn.load_encoder('ft_enc')
File "C:\Users\user\Desktop\py_code\env\lib\site-packages\fastai\text\learner.py", line 68, in load_encoder
encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth'))
File "C:\Users\user\Desktop\py_code\env\lib\site-packages\torch\nn\modules\module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
因此,在加载编码器时会发生错误。但是,我也尝试删除了load_encoder
行,但是下一行learn.load('sentiment_model')
发生了同样的错误。
我在fast-ai论坛中进行了搜索,发现其他人也遇到了此问题,但没有找到解决方案。在this post中,用户说这可能与不同的预处理有关,尽管我不知道为什么会这样。
有人知道我在做什么错吗?
答案 0 :(得分:2)
似乎data_clas和data_lm的词汇量不同。我猜问题是由data_clas和data_lm中使用的不同预处理引起的。要检查我的猜测,我只是用
data_clas.vocab.itos = data_lm.vocab.itos
在下一行之前
learn_c = text_classifier_learner(data_clas,AWD_LSTM,drop_mult = 0.3)
这已修复错误。