我有一个包含钩子的模块,我想用jit的跟踪进行编译:
compiled_model = torch.jit.trace(model, torch.rand(1, 3, 256, 256))
但是我得到了错误:
ValueError: Modules that have hooks assigned can't be compiled
如何强制跟踪忽略钩子?
答案 0 :(得分:0)
如果要绕过跟踪检查,可以从模型中递归删除所有挂钩。
这可以通过遍历子代来完成:
from collections import OrderedDict
def remove_hooks(model):
model._backward_hooks = OrderedDict()
model._forward_hooks = OrderedDict()
model._forward_pre_hooks = OrderedDict()
for child in model.children():
remove_hooks(child)
然后您可以强制编译:
remove_hooks(model)
compiled_model = torch.jit.trace(model, torch.rand(1, 3, 256, 256))
但是,如果钩子实际上在做真实的工作,并且您想让它们保持跟踪(这就是我的情况),您可以在torch/jit/__init__.py
行中注释炬的加薪:
if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
raise ValueError("Modules that have hooks assigned can't be compiled")
它对我有用,我设法编译了一个fastai模型。