是否可以查看火炬预训练网络的代码

时间:2018-04-28 06:02:06

标签: torch pytorch pre-trained-model

如果你在阅读标题时想到这样的菜鸟 - 是的,我是。

我用谷歌搜索但没有找到一个指南,让我可以查看预先训练过的火炬神经网络是如何设计/编码的。我已经下载了预先训练好的网络(文件格式.t7),我安装了火炬。任何人都可以帮助我查看它是如何编码的(使用了什么尺寸的过滤器,使用的参数等)?

可能不是谷歌,因为它不可能吗?很乐意回答您的任何其他问题,或者有什么不清楚。

谢谢。

1 个答案:

答案 0 :(得分:0)

我认为无法获得底层代码。但是你可以通过使用print来获得模型的摘要,其中包括图层和主要参数。

model = SumModel(vocab_size=vocab_size, hiddem_dim=hidden_dim, batch_size=batch_size)
# saving model
torch.save(model, 'test_model.save')
# print summary of original
print(' - original model summary:')
print(model)
print()

# load saved model
loaded_model = torch.load('test_model.save')
# print summary of loaded model
print(' - loaded model summary:')
print(loaded_model)

这将输出如下所示的摘要。

  - original model summary:
SumModel(
  (word_embedding): Embedding(530734, 128)
  (encoder): LSTM(128, 128, batch_first=True)
  (decoder): LSTM(128, 128, batch_first=True)
  (output_layer): Linear(in_features=128, out_features=530734, bias=True)
)

 - loaded model summary:
SumModel(
  (word_embedding): Embedding(530734, 128)
  (encoder): LSTM(128, 128, batch_first=True)
  (decoder): LSTM(128, 128, batch_first=True)
  (output_layer): Linear(in_features=128, out_features=530734, bias=True)
)

使用Pytorch 0.4.0进行测试

正如您所看到的,原始模型和加载模型的输出都是一致的。

我希望这会有所帮助。