I want to modify some functions in torch.nn.functional and then reload the modified torch.nn.functional with importlib.reload. However, I get a runtime error: function 'avg_pool2d' already has a docstring.
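I am working with excitation backprop (https://github.com/greydanus/excitationbp), which provides an operation called eb.use_eb() that modifies torch.nn.functional. eb.use_eb() replaces torch.nn.linear, torch.nn.functional.conv1d, and similar functions. I am modifying utils.py, which calls eb.use_eb() in many places. Whenever I want to run the original torch.nn.linear, I have to call eb.use_eb(False). But torch.nn.functional is already imported at the top of the file (import torch.nn.functional), so I tried to reload the module torch.nn.functional. When I call importlib.reload, I get an error. The relevant part of my modified utils.py: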
eb.use_eb(False,True)
importlib.reload(torch)
importlib.reload(torch.nn)
importlib.reload(torch.nn.functional)
print("gradient_evidence:"+str(gradient_evidence))
return torch.autograd.grad(contr_h_, target_h_, grad_outputs=gradient_evidence)[0]
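Since the goal is only to switch between the patched and the original functions, reloading may not be needed at all. A common alternative is to keep a reference to the original function object and put it back when done. The sketch below assumes the patching is plain attribute assignment on the module; it uses a hypothetical stand-in module (fake_functional) and a hypothetical replacement (eb_linear) so it runs without torch, and it is not how excitationbp itself implements eb.use_eb().

```python
import contextlib
import types


@contextlib.contextmanager
def temporarily_replaced(module, name, replacement):
    """Swap module.<name> for `replacement`, restoring the original on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield original
    finally:
        # Restore even if the body raised, so later code sees the original.
        setattr(module, name, original)


# Hypothetical stand-in for torch.nn.functional so the sketch runs without torch.
fake_functional = types.ModuleType("fake_functional")
fake_functional.linear = lambda x: x * 2


def eb_linear(x):
    # Hypothetical replacement playing the role of a patched function.
    return x * 3


with temporarily_replaced(fake_functional, "linear", eb_linear):
    patched_result = fake_functional.linear(10)  # uses eb_linear

original_result = fake_functional.linear(10)  # original is back
```

This only affects callers that look the function up through the module at call time (for example F.linear(...)); code that did "from torch.nn.functional import linear" keeps its own reference and will not see the swap. The standard library's unittest.mock.patch.object implements the same save-and-restore idea.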
RuntimeError Traceback (most recent call last)
<ipython-input-15-29d0fca29882> in <module>()
----> 1 prob_inputs_one = eb.excitation_backprop(model, inputs, prob_outputs_one, contrastive=2)
2 #pdb.set_trace()
3 prob_inputs_true = eb.excitation_backprop(model, inputs, prob_outputs_true, contrastive=2)
~/code/excitationbp/excitationbp-master/excitationbp/utils.py in excitation_backprop(model, inputs, prob_outputs, contrastive, target_layer)
75 importlib.reload(torch)
76 importlib.reload(torch.nn)
---> 77 importlib.reload(torch.nn.functional)
78
79 gradient_evidence = torch.autograd.grad(top_h_, contr_h_, grad_outputs=prob_outputs.clone())[0]
~/anaconda3/envs/pytorch0.3/lib/python3.6/importlib/__init__.py in reload(module)
164 target = module
165 spec = module.__spec__ = _bootstrap._find_spec(name, pkgpath, target)
--> 166 _bootstrap._exec(spec, module)
167 # The module may have replaced itself in sys.modules!
168 return sys.modules[name]
~/anaconda3/envs/pytorch0.3/lib/python3.6/importlib/_bootstrap.py in _exec(spec, module)
~/anaconda3/envs/pytorch0.3/lib/python3.6/importlib/_bootstrap_external.py in exec_module(self, module)
~/anaconda3/envs/pytorch0.3/lib/python3.6/importlib/_bootstrap.py in _call_with_frames_removed(f, *args, **kwds)
~/anaconda3/envs/pytorch0.3/lib/python3.6/site-packages/torch/nn/functional.py in <module>()
286 count_include_pad: when True, will include the zero-padding in the
287 averaging calculation. Default: ``True``
--> 288 """)
289
290 avg_pool3d = _add_docstr(torch._C._nn.avg_pool3d, r"""
RuntimeError: function 'avg_pool2d' already has a docstring
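The traceback suggests why the reload fails: importlib.reload re-executes the top-level code of functional.py, which calls _add_docstr again on C function objects (such as torch._C._nn.avg_pool2d) that were not recreated, because the compiled extension is not re-initialized by a reload. The sketch below simulates that check in pure Python. BuiltinStandIn, add_docstr_once, and run_module_body are hypothetical names invented for illustration; they mimic the behavior shown in the error message, not PyTorch's actual C implementation.

```python
class BuiltinStandIn:
    """Hypothetical stand-in for a C function object such as torch._C._nn.avg_pool2d."""

    def __init__(self, name):
        self.name = name
        self.doc = None


def add_docstr_once(fn, doc):
    # Mirrors the check seen in the traceback: refuse to set a docstring twice.
    if fn.doc is not None:
        raise RuntimeError(f"function '{fn.name}' already has a docstring")
    fn.doc = doc
    return fn


# The extension-level object survives a reload; only the Python module body re-runs.
avg_pool2d_c = BuiltinStandIn("avg_pool2d")


def run_module_body():
    # Stand-in for executing functional.py's top-level code.
    return add_docstr_once(avg_pool2d_c, "placeholder docstring")


run_module_body()  # first import: succeeds

reload_error = None
try:
    run_module_body()  # a reload re-executes the body against the same object
except RuntimeError as exc:
    reload_error = str(exc)
```

Under this reading, reloading torch.nn.functional cannot succeed as long as the same extension objects are reused, which is why restoring the original function objects is usually simpler than reloading.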