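In Keras, after creating a model, we can inspect its input and output shapes with model.input_shape and model.output_shape. For the weights and the configuration, we can use model.get_weights() and model.get_config() respectively. What are the equivalent alternatives in PyTorch? Are there any other functions worth knowing for inspecting a PyTorch model?
As for a summary, I know that in PyTorch we can print the model with print(model), but that gives less information than Keras's model.summary(). Is there a better summary in PyTorch?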
Answer 0 (score: 1)
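PyTorch has no built-in model.summary() method. You need to use the built-in methods and the model's attributes instead.
For example, I customized the inception_v3 model. To get information about it, I have to use several different attributes. For example: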
IN:
print(model) # print network architecture
Output:
Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
...
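print(model) lists the layers but not the tensor shapes flowing through them, which is part of what Keras's model.summary() reports. One way to recover them is to attach forward hooks and run a dummy input. The following is only a minimal sketch: the toy nn.Sequential model, the 32x32 input size, and the names output_shapes and make_hook are assumptions for illustration, not part of the inception_v3 setup above.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the customized inception_v3.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

output_shapes = {}  # module name -> shape of its output tensor


def make_hook(name):
    # Build a forward hook that records the output shape under `name`.
    def hook(module, inputs, output):
        output_shapes[name] = tuple(output.shape)
    return hook


# Register hooks on leaf modules only (modules without children).
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if len(list(module.children())) == 0
]

model.eval()
with torch.no_grad():
    model(torch.zeros(1, 3, 32, 32))  # assumed dummy input size

# Remove the hooks so they do not fire on later forward passes.
for handle in handles:
    handle.remove()

for name, shape in output_shapes.items():
    print(name, shape)
```

The shapes depend on the dummy input, so the input size has to be chosen to match what the real model expects.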
IN:
for i in model.state_dict().keys():
    print(i)
# print the state_dict keys, which name the learned weights, biases, and buffers
Output:
Conv2d_1a_3x3.conv.weight
Conv2d_1a_3x3.bn.weight
Conv2d_1a_3x3.bn.bias
Conv2d_1a_3x3.bn.running_mean
Conv2d_1a_3x3.bn.running_var
Conv2d_1a_3x3.bn.num_batches_tracked
Conv2d_2a_3x3.conv.weight
Conv2d_2a_3x3.bn.weight
Conv2d_2a_3x3.bn.bias
Conv2d_2a_3x3.bn.running_mean
...
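The state_dict keys can also be paired with tensor shapes, and the parameters can be counted, which gets closer to what model.summary() reports in Keras. This is a rough sketch on a small stand-in model; the two-layer nn.Sequential and the variable names are assumptions for illustration.

```python
import torch.nn as nn

# Hypothetical small model; the real inception_v3 works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
)

# Each state_dict entry maps a name to a tensor (parameter or buffer).
for key, tensor in model.state_dict().items():
    print(key, tuple(tensor.shape))

# model.parameters() yields only learnable parameters, not buffers
# such as running_mean or num_batches_tracked.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("total:", total_params, "trainable:", trainable_params)
```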
So, to get the weights of the convolutional layer at Conv2d_1a_3x3, I look up the key "Conv2d_1a_3x3.conv.weight":
print(model.state_dict()["Conv2d_1a_3x3.conv.weight"])
Output:
tensor([[[[-0.2103, -0.3441, -0.0344],
[-0.1420, -0.2520, -0.0280],
[ 0.0736, 0.0183, 0.0381]],
[[ 0.1417, 0.1593, 0.0506],
[ 0.0828, 0.0854, 0.0186],
[ 0.0283, 0.0144, 0.0508]],
...
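The same weights can also be reached through the module attributes, since a state_dict key is just the dotted attribute path. A minimal sketch of both routes follows; TinyNet and its conv attribute are hypothetical names used only for this illustration.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical one-layer model used only for this illustration.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, bias=False)

    def forward(self, x):
        return self.conv(x)


model = TinyNet()

# Route 1: look up the dotted key in the state_dict (detached tensor).
w_from_dict = model.state_dict()["conv.weight"]

# Route 2: follow the attribute path (an nn.Parameter that tracks gradients).
w_from_attr = model.conv.weight

same_values = torch.equal(w_from_dict, w_from_attr.detach())
print(same_values, w_from_dict.requires_grad, w_from_attr.requires_grad)
```

The state_dict route is convenient for saving or inspecting values, while the attribute route gives the live parameter used during training.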
To see the hyperparameters used by the optimizer:
optimizer.param_groups
OUT:
[{'dampening': 0,
'lr': 0.01,
'momentum': 0.01,
'nesterov': False,
'params': [Parameter containing:
tensor([[[[-0.2103, -0.3441, -0.0344],
[-0.1420, -0.2520, -0.0280],
[ 0.0736, 0.0183, 0.0381]],
...
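Because each param group also holds the parameter tensors themselves, printing optimizer.param_groups directly is noisy. A small sketch that shows only the hyperparameters is given below; the nn.Linear model and the SGD settings are assumptions chosen to mirror the lr and momentum values in the output above.

```python
import torch
import torch.nn as nn

# Hypothetical model and optimizer, for illustration only.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.01)

hyperparams = []
for group in optimizer.param_groups:
    # Drop the "params" entry, which holds the tensors being optimized.
    hyper = {k: v for k, v in group.items() if k != "params"}
    hyperparams.append(hyper)

for i, hyper in enumerate(hyperparams):
    print("group", i, hyper)
```

The exact set of keys printed depends on the optimizer class and the PyTorch version.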