叶节点

时间:2017-07-04 20:37:00

标签: torch pytorch

来自docs

  

requires_grad - 指示变量是否已存在的布尔值   由包含任何变量的子图创建,需要它。能够   仅在叶子变量上更改

  1. 这里的叶子节点是什么意思?叶节点只是输入节点吗?
  2. 如果只能在叶节点上更改,那么如何冻结图层?

1 个答案:

答案 0 :(得分:7)

  1. 图的叶节点是那些不是直接从图中的其他节点计算的节点(即Variables)。例如:

    import torch
    from torch.autograd import Variable
    
    A = Variable(torch.randn(10,10)) # this is a leaf node
    B = 2 * A # this is not a leaf node
    w = Variable(torch.randn(10,10)) # this is a leaf node
    C = A.mm(w) # this is not a leaf node
    

    如果叶节点requires_grad,从它计算的所有后续节点也将自动require_grad。否则,您无法应用链规则来计算requires_grad的叶节点的渐变。这就是为什么requires_grad只能为叶节点设置的原因:对于所有其他节点,可以巧妙地推断它,并且实际上由用于计算这些其他变量的叶节点的设置确定。

  2. 请注意,在典型的神经网络中,所有参数都是叶节点。它们不是从网络中的任何其他Variables计算的。因此,使用requires_grad冻结图层很简单。这里是一个来自PyTorch文档的例子:

    model = torchvision.models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace the last fully-connected layer
    # Parameters of newly constructed modules have requires_grad=True by default
    model.fc = nn.Linear(512, 100)
    
    # Optimize only the classifier
    optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
    

    尽管如此,你真正做的是冻结整个梯度计算(这是你应该做的,因为它避免了不必要的计算)。从技术上讲,您可以保留requires_grad标志,并仅为您想要学习的参数子集定义优化器。