我正在使用pytorch来训练部分网络。例如,我有一个模型结构
hidden1 = Layer1(x)
hidden2 = Layer2(hidden1)
out = Layer3(hidden2)
如果我只想训练Layer3,我可以使用
hidden1 = Layer1(x)
hidden2 = Layer2(hidden1).detach()
out = Layer3(hidden2)
但是,这次我只想训练Layer1。我该如何实现?谢谢。
答案 0 :(得分:1)
detach
不会真正“冻结”您的图层。
如果您不想训练图层,则应改用requires_grad=False
。
例如:
hidden2.weight.requires_grad = False
hidden2.bias.requires_grad = False
然后解冻,请对requires_grad=True
执行相同的操作。