如何计算PyTorch模型中的参数总数?类似于Keras的model.count_params()
。
答案 0 :(得分:44)
PyTorch没有像Keras那样计算参数总数的功能,但是可以对每个参数组的元素数量求和:
pytorch_total_params = sum(p.numel() for p in model.parameters())
如果您只想计算 trainable 参数:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
在PyTorch论坛上以answer为灵感的答案。
注意:我是answering my own question。如果有人有更好的解决方案,请与我们分享。
答案 1 :(得分:10)
要获取Keras之类的每个层的参数计数,PyTorch具有model.named_paramters(),它返回参数名称和参数本身的迭代器。
这里是一个例子:
from prettytable import PrettyTable
def count_parameters(model):
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
param = parameter.numel()
table.add_row([name, param])
total_params+=param
print(table)
print(f"Total Trainable Params: {total_params}")
return total_params
count_parameters(net)
输出看起来像这样:
+-------------------+------------+
| Modules | Parameters |
+-------------------+------------+
| embeddings.weight | 922866 |
| conv1.weight | 1048576 |
| conv1.bias | 1024 |
| bn1.weight | 1024 |
| bn1.bias | 1024 |
| conv2.weight | 2097152 |
| conv2.bias | 1024 |
| bn2.weight | 1024 |
| bn2.bias | 1024 |
| conv3.weight | 2097152 |
| conv3.bias | 1024 |
| bn3.weight | 1024 |
| bn3.bias | 1024 |
| lin1.weight | 50331648 |
| lin1.bias | 512 |
| lin2.weight | 265728 |
| lin2.bias | 519 |
+-------------------+------------+
Total Trainable Params: 56773369
答案 2 :(得分:1)
如果您想在不实例化模型的情况下计算每层的权重和偏差的数量,则可以简单地加载原始文件并像这样循环遍历生成的JobBuilder.newJob(ReadTest.class)
:
collections.OrderedDict
您会得到类似的东西
import torch
tensor_dict = torch.load('model.dat', map_location='cpu') # OrderedDict
tensor_list = list(tensor_dict.items())
for layer_tensor_name, tensor in tensor_list:
print('Layer {}: {} elements'.format(layer_tensor_name, torch.numel(tensor)))
答案 3 :(得分:1)
如果要避免重复计算共享参数,可以使用torch.Tensor.data_ptr
。例如:
sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
这是一个更为冗长的实现,其中包括一个用于过滤掉不可训练参数的选项:
def numel(m: torch.nn.Module, only_trainable: bool = False):
"""
returns the total number of parameters used by `m` (only counting
shared parameters once); if `only_trainable` is True, then only
includes parameters with `requires_grad = True`
"""
parameters = m.parameters()
if only_trainable:
parameters = list(p for p in parameters if p.requires_grad)
unique = dict((p.data_ptr(), p) for p in parameters).values()
return sum(p.numel() for p in unique)
答案 4 :(得分:1)
您可以使用torchsummary
做同样的事情。只是两行代码。
from torchsummary import summary
print(summary(model, (input_shape)))