Loading a pretrained model in PyTorch

Posted: 2019-10-09 17:26:42

Tags: deep-learning nlp pytorch

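First, I apologize if this question sounds stupid, but I'm new to deep learning. Could someone explain the following lines of code, which are used to load a pretrained model in PyTorch?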

# Retrieving model parameters from checkpoint.
vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
embedding_dim = checkpoint["model"]['_word_embedding.weight'].size(1)
hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

I don't understand what `_projection`, `weight`, `_classification`, `size(0)` and `size(1)` mean in the lines above.

1 answer:

Answer 0 (score: 1)

import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()

        vocab_size = 10000
        embed_size = 100
        hidden_size = 50
        # word embedding layer: weight shape (vocab_size, embed_size)
        self._word_embedding = nn.Embedding(vocab_size, embed_size)
        # two linear projection layers (no bias): weight shape (hidden_size, embed_size)
        self._projection = nn.ModuleList([nn.Linear(embed_size, hidden_size, bias=False)
                                          for _ in range(2)])
        # four linear classification layers (no bias): weight shape (hidden_size, hidden_size)
        self._classification = nn.ModuleList([nn.Linear(hidden_size, hidden_size, bias=False)
                                              for _ in range(4)])

    def forward(self):
        # no forward pass is needed just to inspect the parameters
        return


model = Model()
checkpoint = {
    'model': model.state_dict()  # OrderedDict
}

# _word_embedding.weight --> torch.Size([10000, 100])
# _projection.0.weight --> torch.Size([50, 100])
# _projection.1.weight --> torch.Size([50, 100])
# _classification.0.weight --> torch.Size([50, 50])
# _classification.1.weight --> torch.Size([50, 50])
# _classification.2.weight --> torch.Size([50, 50])
# _classification.3.weight --> torch.Size([50, 50])

for name, param in checkpoint['model'].items():
    print(name, '-->', param.size()) # see above

# individual dimensions can also be read with .size(dim)
print(checkpoint["model"]["_word_embedding.weight"].size(0)) # 10000
print(checkpoint["model"]["_word_embedding.weight"].size(1)) # 100
print(checkpoint["model"]["_projection.0.weight"].size(0)) # 50
print(checkpoint["model"]["_classification.0.weight"].size(0)) # 50 (this toy model has no index 4)

I prepared the example above to help you understand what these four lines mean.
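The example above has no key with index 4, such as the question's `_classification.4.weight`. The number in a key is the position of a submodule inside a container like `nn.ModuleList` or `nn.Sequential`, and only submodules that own parameters show up in the `state_dict`. The following is a minimal sketch, assuming (hypothetically) that the original model's `_classification` is an `nn.Sequential` whose fifth element is the final `nn.Linear`; the dropout and activation layers and all sizes are invented for illustration:

```python
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical head: only the Linear layers (indices 1 and 4) have parameters.
        self._classification = nn.Sequential(
            nn.Dropout(0.5),     # index 0, no parameters
            nn.Linear(100, 50),  # index 1
            nn.Tanh(),           # index 2, no parameters
            nn.Dropout(0.5),     # index 3, no parameters
            nn.Linear(50, 3),    # index 4, the final (output) layer
        )


state = Model().state_dict()
keys = list(state.keys())
print(keys)
# ['_classification.1.weight', '_classification.1.bias',
#  '_classification.4.weight', '_classification.4.bias']

# nn.Linear stores its weight as (out_features, in_features),
# so size(0) of the last layer is the number of output classes.
print(state["_classification.4.weight"].size(0))  # 3
```

In this layout, reading `size(0)` of the last `Linear` layer's weight recovers the number of classes, which is what the question's `num_classes` line does.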

  

I don't understand what `_projection`, `weight`, `_classification`, `size(0)` and `size(1)` mean in the lines above.

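  • `_projection`: the attribute name of a neural network layer (or a container of layers) in the model. The `.0` after it is the index of a submodule inside an `nn.ModuleList` or `nn.Sequential`.
  • `_classification`: likewise, the name of another layer or container of layers; `.4` is the index of a submodule within it.
  • `weight`: the weight matrix (a tensor) of the corresponding layer.
  • `size(0)`: the size of the first dimension of the weight matrix. For `nn.Embedding`, whose weight has shape `(num_embeddings, embedding_dim)`, this is the vocabulary size; for `nn.Linear`, whose weight has shape `(out_features, in_features)`, it is the number of output features.
  • `size(1)`: the size of the second dimension of the weight matrix. For `nn.Embedding` this is the embedding dimension.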
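The reason for reading these sizes is usually to rebuild a model with exactly the same shapes and then copy the saved tensors into it with `load_state_dict`. The sketch below does that round trip. The `Classifier` class is hypothetical: it was designed only so that its parameter names match the keys in the question, and the real architecture may differ. In practice `checkpoint` would come from `torch.load(...)` on the saved file; here it is built in memory instead.

```python
import torch.nn as nn


class Classifier(nn.Module):
    # Hypothetical architecture whose parameter names match the question's keys.
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes):
        super().__init__()
        self._word_embedding = nn.Embedding(vocab_size, embedding_dim)
        self._projection = nn.Sequential(nn.Linear(embedding_dim, hidden_size), nn.ReLU())
        self._classification = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, num_classes),  # index 4
        )


# Stand-in for a checkpoint produced earlier by training code.
checkpoint = {"model": Classifier(10000, 100, 50, 3).state_dict()}

# nn.Embedding.weight has shape (num_embeddings, embedding_dim)
vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
# nn.Linear.weight has shape (out_features, in_features)
hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

# Rebuild an identically shaped model, then copy the saved tensors into it.
# load_state_dict raises an error if any key or shape does not match.
model = Classifier(vocab_size, embedding_dim, hidden_size, num_classes)
model.load_state_dict(checkpoint["model"])
print(vocab_size, embedding_dim, hidden_size, num_classes)  # 10000 100 50 3
```

Inferring sizes from the checkpoint avoids hard-coding hyperparameters that must match the saved weights exactly.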