打印PyTorch输入和输出名称

时间:2019-02-20 16:12:13

标签: python-3.x pytorch

我可以像这样打印出我的pytorch模型的类似Keras的摘要:

from torchsummary import summary
summary(model, input_size)

这会打印出这样的内容:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0

但是,为了进行图形优化,我也需要模型的输入名称,例如{'input0':[1, 3, 224, 224]}。那么如何获得这些输入(和输出)名称?

0 个答案:

没有答案