Fastai学习者未加载

时间:2019-02-27 20:39:16

标签: machine-learning model pytorch resnet fast-ai

因此,我尝试使用以下方式加载模型:

learn = create_cnn(data, models.resnet50, lin_ftrs=[2048], metrics=accuracy) 
learn.clip_grad();
learn.load(f'{name}-stage-2.1')

但是我收到以下错误

RuntimeError: Error(s) in loading state_dict for Sequential:
size mismatch for 1.8.weight: copying a param with shape torch.Size([5004, 2048]) from checkpoint, the shape in current model is torch.Size([4542, 2048]).
size mismatch for 1.8.bias: copying a param with shape torch.Size([5004]) from checkpoint, the shape in current model is torch.Size([4542]).

唯一不同的是,我添加了stage-2.1模型中不存在的随机验证分组,当我删除了该分组并且未设置stage-2.1的验证集时,受过训练的一切顺利。

发生了什么事?

2 个答案:

答案 0 :(得分:3)

使用cnn_learner方法和最新的Pytorch和最新的FastAI。有breaking change和间断,所以您现在受苦了。

fastai网站上有许多示例,例如this one

learn = cnn_learner(data, models.resnet50, metrics=accuracy)

答案 1 :(得分:2)

实际上是从检查点开始的torch.Size([5004,2048]),当前模型中的形状是torch.Size([4542,2048]) 您必须更改它。