从pytorch中的state_dict检查参数数量

时间:2019-09-01 06:43:03

标签: pytorch

SO有一个关于如何从模型中检查参数总数的答案: pytorch_total_params = sum(p.numel() for p in model.parameters())

但是,如何检查来自的参数总数 state_dict

state_dict = torch.load(model_path, map_location='cpu')

1 个答案:

答案 0 :(得分:3)

您可以在state_dict中计算保存的条目数:

export const updateCardMutation = {
async updateCard(_, { id, patch }): Promise<Card> {
    const repository = getRepository(Card);
    const card = await repository.findOne({ id });
    const result = await repository.update(id, {...patch}); // here
    return {
        ...card,
        ...patch,
    };
},

但是,这里有一个障碍:state_dict同时存储parameters persistent buffers(例如BatchNorm的均值和var)。除了state_dict本身之外,没有任何方法(AFAIK)告诉它们,您需要将它们加载到模型中,并使用sum(p.numel() for p in state_dict.values()) 仅计算参数。

例如,如果您结帐resnet50

sum(p.numel() for p in model.parameters()

结果

from torchvision.models import resnet50
model = resnet50(pretrained=True)
state_dict = torch.load('~/.torch/models/resnet50-19c8e357.pth')
num_parameters = sum(p.numel() for p in model.parameters())
num_state_dict = sum(p.numel() for p in state_dict.values())
print('num parameters = {}, stored in state_dict = {}, diff = {}'.format(num_parameters, num_state_dict, num_state_dict - num_parameters))

如您所见,两个值之间可能会有很大的差距。