我在官方的pytorch教程中得到了这个nn结构:
输入 - > conv2d - > relu - > maxpool2d - > conv2d - > relu - > maxpool2d - >查看 - >线性 - > relu - >线性 - > relu - >线性 - > MSELoss - >损失
然后是一个如何使用Variable中的内置.grad_fn向后关注grad的示例。
# Eg:
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
所以我认为我可以通过粘贴next_function [0] [0] 9次来达到Conv2d的grad对象,因为给定的例子但我从索引中得到了错误元组。那么如何正确索引这些backprop对象呢?
答案 0 :(得分:3)
在完成本教程中的以下操作后,在PyTorch CNN tutorial中:
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
以下代码段将显示完整图形:
def print_graph(g, level=0):
if g == None: return
print('*'*level*4, g)
for subg in g.next_functions:
print_graph(subg[0], level+1)
print_graph(loss.grad_fn, 0)
答案 1 :(得分:0)
尝试运行
print(loss.grad_fn.next_functions[0][0].next_functions)
您将看到这给出了一个包含三个元素的数组。实际上,这是您要选择的[1] [0]元素,否则,您将获得累积的Grad,并且不能再超出此范围。深入研究时,您会发现可以一路穿越网络。例如,尝试运行:
print(loss.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)
先运行.next_functions而不索引,然后查看需要选择哪个元素才能到达nn的下一层。