如何通过忽略钩子来强制torch.jit.trace编译我的模块?

时间:2019-05-21 16:40:32

标签: torch

我有一个包含钩子的模块,我想用jit的跟踪进行编译:

compiled_model = torch.jit.trace(model,  torch.rand(1, 3, 256, 256))

但是我得到了错误:

ValueError: Modules that have hooks assigned can't be compiled

如何强制跟踪忽略钩子?

1 个答案:

答案 0 :(得分:0)

如果要绕过跟踪检查,可以从模型中递归删除所有挂钩。

这可以通过遍历子代来完成:

from collections import OrderedDict
def remove_hooks(model):
    model._backward_hooks = OrderedDict()
    model._forward_hooks = OrderedDict()
    model._forward_pre_hooks = OrderedDict()
    for child in model.children():
        remove_hooks(child)

然后您可以强制编译:

remove_hooks(model)
compiled_model = torch.jit.trace(model,  torch.rand(1, 3, 256, 256))

但是,如果钩子实际上在做真实的工作,并且您想让它们保持跟踪(这就是我的情况),您可以在torch/jit/__init__.py行中注释炬的加薪:

if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
    raise ValueError("Modules that have hooks assigned can't be compiled")

它对我有用,我设法编译了一个fastai模型。