pytorch中的暹罗神经网络

时间:2018-12-16 15:57:10

标签: python pytorch

如何在PyTorch中实现暹罗神经网络?

  

什么是暹罗神经网络?暹罗神经网络由两个相同神经网络组成,每个神经网络只有一个输入。 Identical 表示两个神经网络具有完全相同的体系结构并共享相同的权重。

enter image description here

1 个答案:

答案 0 :(得分:3)

在PyTorch中实现暹罗神经网络就像在不同的输入上两次调用网络函数一样简单。

mynet = torch.nn.Sequential(
        nn.Linear(10, 512),
        nn.ReLU(),
        nn.Linear(512, 2))
...
output1 = mynet(input1)
output2 = mynet(input2)
...
loss.backward()

调用loss.backwad()时,PyTorch会自动对来自mynet的两次调用的梯度求和。

您可以找到完整的示例here