我是Pytorch的初学者,我正在学习如何使用钩子。我想看看钩子是怎么回事。以下是我尝试过的内容
def get_features_hook(module,input,output):
print(input)
print(output)
handle_feat = alexnet.features[0].register_forward_hook(get_features_hook)
a = alexnet(input_data)
a.backward(torch.ones(1, num_class))
handle_feat.remove()
但是我得到了这个错误:
TypeError:get_features_hook()缺少1个必需的位置参数:“输出”
我也尝试过:
def get_features_hook(input,output)
and
def get_features_hook(self,input,output)
我得到了相同的结果,
TypeError:get_features_hook()缺少1个必需的位置参数:“输出”
我想知道是否有人可以帮助我解决该问题,并告诉我代码有什么问题。