SO有一个关于如何从模型中检查参数总数的答案:
pytorch_total_params = sum(p.numel() for p in model.parameters())
但是,如何检查来自的参数总数
state_dict
?
state_dict = torch.load(model_path, map_location='cpu')
?
答案 0 :(得分:3)
您可以在state_dict中计算保存的条目数:
export const updateCardMutation = {
async updateCard(_, { id, patch }): Promise<Card> {
const repository = getRepository(Card);
const card = await repository.findOne({ id });
const result = await repository.update(id, {...patch}); // here
return {
...card,
...patch,
};
},
但是,这里有一个障碍:state_dict同时存储parameters 和 persistent buffers(例如BatchNorm的均值和var)。除了state_dict本身之外,没有任何方法(AFAIK)告诉它们,您需要将它们加载到模型中,并使用sum(p.numel() for p in state_dict.values())
仅计算参数。
例如,如果您结帐resnet50
sum(p.numel() for p in model.parameters()
结果
from torchvision.models import resnet50 model = resnet50(pretrained=True) state_dict = torch.load('~/.torch/models/resnet50-19c8e357.pth') num_parameters = sum(p.numel() for p in model.parameters()) num_state_dict = sum(p.numel() for p in state_dict.values()) print('num parameters = {}, stored in state_dict = {}, diff = {}'.format(num_parameters, num_state_dict, num_state_dict - num_parameters))
如您所见,两个值之间可能会有很大的差距。