使用MNIST数据集Pytorch训练SqueezeNet模型

时间:2018-12-03 11:56:18

标签: python neural-network pytorch mnist torchvision

我想使用MNIST数据集而不是ImageNet数据集训练SqueezeNet 1.1模型。
我可以使用与torchvision.models.squeezenet相同的模型吗?
谢谢!

2 个答案:

答案 0 :(得分:2)

TorchVision仅为SqueezeNet体系结构提供ImageNet数据预训练模型。但是,您可以使用MNIST数据集来训练自己的模型,只需从torchvision.models中获取模型(而不是预先训练的模型)即可。

In [10]: import torchvision as tv

# get the model architecture only; ignore `pretrained` flag
In [11]: squeezenet11 = tv.models.squeezenet1_1()

In [12]: squeezenet11.training   
Out[12]: True

现在,您可以使用此体系结构在MNIST数据上训练模型,这应该不会花费太长时间。


要记住的一种修改是更新MNIST的类数为10。具体来说,应将1000更改为10,并相应地调整内核和步幅。

  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
    (2): ReLU(inplace)
    (3): AvgPool2d(kernel_size=13, stride=1, padding=0)
  )

以下是相关说明:finetuning_torchvision_models-squeezenet

答案 1 :(得分:0)

可以对预训练权重进行初始化,但是由于MNIST图像为28X28像素,因此您会在步幅和内核大小上遇到麻烦。减少很有可能会导致在网处于其下层之前生成(batch_sizex1x1xchannel)特征图,然后将导致错误。