跟踪火炬网络中的所有权重和梯度?

时间:2020-02-23 17:48:29

标签: python pytorch

我正在尝试掌握训练期间各层的统计(均值,标准差)属性。
我目前正在实现前向和后向挂钩,我相信这是成千上万的以前写过的东西,但是找不到任何有关此类回购的参考。
我的代码简单地将所有权重,输出和渐变汇总到字典中(以及随后的可视化):

def track_network(Net,moduleList):
    outputs = {}
    weights = {}
    gradients = {}
    layers = list(dict(moduleList.named_children()).values())

    def get_activation(name):
        def hook(model, input, output):
            outputs.setdefault(model.__dict__['module_name'], []).append(output.detach().cpu().view(-1).numpy())
            weights.setdefault(model.__dict__['module_name'], []).append(model.weight.detach().cpu().view(-1).numpy())
        return hook

    def get_grad_hook(name):
        def grad_hook(model, grad_input, grad_output):
            gradients.setdefault(model.__dict__['module_name'], []).append(grad_output)
        return grad_hook

    def register_childern(children_dict,parent=""):
        for module_name, module in children_dict.items():
            if not isinstance(module, torch.nn.Linear):
                register_childern(dict(module.named_children()),parent+module_name)
            else:
                module.__dict__['module_name'] = parent+module_name
                module.register_forward_hook(get_activation(module_name))
                module.register_backward_hook(get_grad_hook(module_name))

    register_childern(dict(Net.named_children()))

我的问题是:是否存在跟踪网络此类属性的标准/直接方法?

0 个答案:

没有答案