Pytorch:如何获得图中的所有张量

时间:2018-12-21 02:29:36

标签: python deep-learning pytorch

我想访问图的所有张量实例。例如,我可以检查张量是否分离或我可以检查其大小。可以在tensorflow中完成。

想要图形的可视化。

1 个答案:

答案 0 :(得分:1)

您可以在运行时访问整个计算图。为此,您可以使用钩子。这些是插入到 nn.Module 上的函数,用于推理和反向传播。

在推理时,您可以使用 register_forward_hook 插入钩子。对于反向传播,您可以使用 register_backward_hook(注意:在 1.8.0 版本中,此函数将被弃用,而支持 register_full_backward_hook)。

通过这两个函数,您基本上可以访问计算图上的任何张量。是否要打印所有张量、打印形状,甚至插入断点进行调查,这完全取决于您。

这是一个可能的实现:

def forward_hook(module, input, output):
    # ...

参数 input 由 PyTorch 作为 元组 传递,并将包含传递给挂钩模块的转发函数的所有参数。

def backward_hook(module, grad_input, grad_output):
    # ...

对于后向钩子,grad_inputgrad_output 都是 元组,并且会根据模型的层具有不同的形状。

然后您可以将这些回调挂接到任何现有的 nn.Module。例如,您可以遍历模型中的所有子模块:

for module in model.children():
    module.register_forward_hook(forward_hook)
    module.register_backward_hook(backward_hook)

要获取模块的名称,您可以将钩子包裹起来以包含名称并在模型的 named_modules 上循环:

def forward_hook(name):
    def hook(module, x, y):
        print('%s: %s -> %s' % (name, list(x[0].size()), list(y.size())))
    return hook

for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))

可以在推理中打印以下内容:

fc1: [1, 100] -> [1, 10]
fc2: [1, 10] -> [1, 5]
fc3: [1, 5] -> [1, 1]

正如我所说,向后传球有点复杂。我只能建议您探索和试验pdb

def backward_hook(module, grad_input, grad_output):
    pdb.set_trace()

至于模型的参数,您可以通过调用 module.parameters 轻松访问两个钩子中给定模块的参数。这将返回一个生成器。

我只能祝你在探索你的模型时好运!