在pytorch模型中保存嵌入层

时间:2019-12-30 10:11:37

标签: pytorch

我有这个模型:

class model(nn.Module):
    def __init__(self):
      super().__init__()
      self.conv1 = nn.Conv2d(in_channels=12,out_channels=64,kernel_size=3,stride= 1,padding=1)
      # self.conv2 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride= 1,padding=1)
      self.fc1 = nn.Linear(24576, 128)
      self.bn = nn.BatchNorm1d(128)
      self.dropout1 = nn.Dropout2d(0.5)
      self.fc2 = nn.Linear(128, 10)
      self.fc3 = nn.Linear(10, 3)

    def forward(self, x):
      x = F.relu(self.conv1(x))
      # x = F.relu(self.conv2(x))
      x = F.max_pool2d(x, (2,2))
      # print(x.shape)
      x = x.view(-1,24576)
      x = self.bn(F.relu(self.fc1(x)))
      x = self.dropout1(x)
      embeding_stage = F.relu(self.fc2(x))
      x = self.fc3(embeding_stage)

      return x

并且我想保存embeding_stage层,就像我在此处保存模型一样:

model = model()
torch.save(model.state_dict(), 'C:\project\count_speakers\model_pytorch.h5')

谢谢, 阿亚尔

1 个答案:

答案 0 :(得分:0)

我不确定我是否理解“保存embedding_stage层”的含义,但是如果要保存fc2或fc3或其他内容,则可以使用torch.save()来实现。
例如:要保存fc3:torch.save(model.fc3),'C:\...\fc3.pt')

编辑:

Op希望获得embedding_stage的输出。
您可以通过几种方式来做到这一点:

  • 使用model.load_state_dict(torch.load('C:\...\model_pytorch.h5'))加载模型 然后model = nn.Sequential(*list(model.children())[:-1])。模型的输出是embeding_stage。

  • 制作一个Model2(nn.Module),与您的第一个Model()完全相同,但是将return x中的def forward(self, x):替换为return embeding_stage。然后像这样将第一个模型的状态加载到第二个模型中:model2.load_state_dict(torch.load('C:\...\model_pytorch.h5'))
    像这样的fc3将被加载,但不会被使用。 model2(x)的输出将为embeding_stage。