有人可以告诉我有关网络参数(10)的计算方法吗?提前谢谢。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
print(len(list(net.parameters())))
输出:
Net(
(conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120)
(fc2): Linear(in_features=120, out_features=84)
(fc3): Linear(in_features=84, out_features=10)
)
10
最佳, 扎克
答案 0 :(得分:3)
PyTorch中的大多数层模块(例如,Linear,Conv2d等)将参数分组为特定类别,例如权重和偏差。您网络中的五个图层实例中的每一个都有一个" weight"和"偏见"参数。这就是为什么" 10"打印出来。
当然,所有这些"重量"和"偏见"字段包含许多参数。例如,您的第一个完全连接的图层self.fc1
包含16 * 5 * 5 * 120 = 48000
个参数。所以len(params)
并没有告诉你网络中的参数数量 - 它只给出了"分组的总数"网络中的参数。
答案 1 :(得分:3)
因为比尔已经回答了为什么" 10"打印,我只是共享一个代码片段,您可以使用它来找出与您网络中每个图层相关的参数数量。
def count_parameters(model):
total_param = 0
for name, param in model.named_parameters():
if param.requires_grad:
num_param = numpy.prod(param.size())
if param.dim() > 1:
print(name, ':', 'x'.join(str(x) for x in list(param.size())), '=', num_param)
else:
print(name, ':', num_param)
total_param += num_param
return total_param
使用以上功能如下。
print('number of trainable parameters =', count_parameters(net))
输出:
conv1.weight : 6x1x5x5 = 150
conv1.bias : 6
conv2.weight : 16x6x5x5 = 2400
conv2.bias : 16
fc1.weight : 120x400 = 48000
fc1.bias : 120
fc2.weight : 84x120 = 10080
fc2.bias : 84
fc3.weight : 10x84 = 840
fc3.bias : 10
number of trainable parameters = 61706