检查PyTorch模型中的参数总数

时间:2018-03-09 19:55:24

标签: deep-learning pytorch

如何计算PyTorch模型中的参数总数?类似于Keras的model.count_params()

5 个答案:

答案 0 :(得分:44)

PyTorch没有像Keras那样计算参数总数的功能,但是可以对每个参数组的元素数量求和:

pytorch_total_params = sum(p.numel() for p in model.parameters())

如果您只想计算 trainable 参数:

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

在PyTorch论坛上以answer为灵感的答案

注意:我是answering my own question。如果有人有更好的解决方案,请与我们分享。

答案 1 :(得分:10)

要获取Keras之类的每个层的参数计数,PyTorch具有model.named_paramters(),它返回参数名称和参数本身的迭代器。

这里是一个例子:

from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(net)

输出看起来像这样:

+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
| embeddings.weight |   922866   |
|    conv1.weight   |  1048576   |
|     conv1.bias    |    1024    |
|     bn1.weight    |    1024    |
|      bn1.bias     |    1024    |
|    conv2.weight   |  2097152   |
|     conv2.bias    |    1024    |
|     bn2.weight    |    1024    |
|      bn2.bias     |    1024    |
|    conv3.weight   |  2097152   |
|     conv3.bias    |    1024    |
|     bn3.weight    |    1024    |
|      bn3.bias     |    1024    |
|    lin1.weight    |  50331648  |
|     lin1.bias     |    512     |
|    lin2.weight    |   265728   |
|     lin2.bias     |    519     |
+-------------------+------------+
Total Trainable Params: 56773369

答案 2 :(得分:1)

如果您想在不实例化模型的情况下计算每层的权重和偏差的数量,则可以简单地加载原始文件并像这样循环遍历生成的JobBuilder.newJob(ReadTest.class)

collections.OrderedDict

您会得到类似的东西

import torch


tensor_dict = torch.load('model.dat', map_location='cpu') # OrderedDict
tensor_list = list(tensor_dict.items())
for layer_tensor_name, tensor in tensor_list:
    print('Layer {}: {} elements'.format(layer_tensor_name, torch.numel(tensor)))

答案 3 :(得分:1)

如果要避免重复计算共享参数,可以使用torch.Tensor.data_ptr。例如:

sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

这是一个更为冗长的实现,其中包括一个用于过滤掉不可训练参数的选项:

def numel(m: torch.nn.Module, only_trainable: bool = False):
    """
    returns the total number of parameters used by `m` (only counting
    shared parameters once); if `only_trainable` is True, then only
    includes parameters with `requires_grad = True`
    """
    parameters = m.parameters()
    if only_trainable:
        parameters = list(p for p in parameters if p.requires_grad)
    unique = dict((p.data_ptr(), p) for p in parameters).values()
    return sum(p.numel() for p in unique)

答案 4 :(得分:1)

您可以使用torchsummary做同样的事情。只是两行代码。

from torchsummary import summary

print(summary(model, (input_shape)))