我可以像这样打印出我的pytorch模型的类似Keras的摘要:
from torchsummary import summary
summary(model, input_size)
这会打印出这样的内容:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
但是,为了进行图形优化,我也需要模型的输入名称,例如{'input0':[1, 3, 224, 224]}
。那么如何获得这些输入(和输出)名称?