Pytorch-使用现有模型的零件或片段

时间:2018-12-19 19:39:31

标签: python keras pytorch

我有一个关于Pytorch的“如何做”问题。

假设我在Pytorch中有两个模型,并且我想在第二个模型中输入一个来自第一个模型中间的输出。 但是第一个模型不是连续的,它非常复杂,并且可能具有内部分支。玩具示例:

Model description

  

如何制作模型1的版本/子类/副本,同时输出最终张量和中间张量?

到目前为止我能做什么?

到目前为止,我发现的唯一方法是了解整个模型及其所有forward连接,因此我精确地复制了forward方法并使其输出所需的中间张量。

但是这非常麻烦,因为我真的不想以高复杂度重建整个模型。这太容易出错了。

有没有更简单的方法,例如在Keras中完成的方法?

看看在Keras中完成它多么容易

只需获取中间层的输出并将其作为基于model1的新模型的输出即可:

middleTensor = model1.get_layer(layer_name).output #or model1.layers[i]
newModel1 = Model(model1.input, [middleTensor,model1.output])

现在我有一个模型1的版本,可以输出最终张量和中间张量,而无需进行任何深入了解或复制它。

  

pytorch中有类似的方法吗?

0 个答案:

没有答案