如何在pytorch中正确使用grad_fn中的next_functions [0] [0]?

时间:2018-03-25 17:32:02

标签: pytorch

我在官方的pytorch教程中得到了这个nn结构:

  

输入    - > conv2d - > relu - > maxpool2d - > conv2d - > relu - > maxpool2d         - >查看 - >线性 - > relu - >线性 - > relu - >线性         - > MSELoss         - >损失

然后是一个如何使用Variable中的内置.grad_fn向后关注grad的示例。

# Eg: 
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

所以我认为我可以通过粘贴next_function [0] [0] 9次来达到Conv2d的grad对象,因为给定的例子但我从索引中得到了错误元组。那么如何正确索引这些backprop对象呢?

2 个答案:

答案 0 :(得分:3)

在完成本教程中的以下操作后,在PyTorch CNN tutorial中:

output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)

以下代码段将显示完整图形:

def print_graph(g, level=0):
    if g == None: return
    print('*'*level*4, g)
    for subg in g.next_functions:
        print_graph(subg[0], level+1)

print_graph(loss.grad_fn, 0)

答案 1 :(得分:0)

尝试运行

print(loss.grad_fn.next_functions[0][0].next_functions)

您将看到这给出了一个包含三个元素的数组。实际上,这是您要选择的[1] [0]元素,否则,您将获得累积的Grad,并且不能再超出此范围。深入研究时,您会发现可以一路穿越网络。例如,尝试运行:

print(loss.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)

先运行.next_functions而不索引,然后查看需要选择哪个元素才能到达nn的下一层。