Pytorch网络参数计算

时间:2018-01-23 03:12:23

标签: deep-learning conv-neural-network pytorch

有人可以告诉我有关网络参数(10)的计算方法吗?提前谢谢。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
print(len(list(net.parameters())))

输出:

Net(
  (conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120)
  (fc2): Linear(in_features=120, out_features=84)
  (fc3): Linear(in_features=84, out_features=10)
)
10

最佳, 扎克

2 个答案:

答案 0 :(得分:3)

PyTorch中的大多数层模块(例如,Linear,Conv2d等)将参数分组为特定类别,例如权重和偏差。您网络中的五个图层实例中的每一个都有一个" weight"和"偏见"参数。这就是为什么" 10"打印出来。

当然,所有这些"重量"和"偏见"字段包含许多参数。例如,您的第一个完全连接的图层self.fc1包含16 * 5 * 5 * 120 = 48000个参数。所以len(params)并没有告诉你网络中的参数数量 - 它只给出了"分组的总数"网络中的参数。

答案 1 :(得分:3)

因为比尔已经回答了为什么" 10"打印,我只是共享一个代码片段,您可以使用它来找出与您网络中每个图层相关的参数数量。

def count_parameters(model):
    total_param = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            num_param = numpy.prod(param.size())
            if param.dim() > 1:
                print(name, ':', 'x'.join(str(x) for x in list(param.size())), '=', num_param)
            else:
                print(name, ':', num_param)
            total_param += num_param
    return total_param

使用以上功能如下。

print('number of trainable parameters =', count_parameters(net))

输出:

conv1.weight : 6x1x5x5 = 150
conv1.bias : 6
conv2.weight : 16x6x5x5 = 2400
conv2.bias : 16
fc1.weight : 120x400 = 48000
fc1.bias : 120
fc2.weight : 84x120 = 10080
fc2.bias : 84
fc3.weight : 10x84 = 840
fc3.bias : 10
number of trainable parameters = 61706