如何将参数从self.some_dictionary_of_modules添加到self.parameters?

时间:2019-03-19 09:48:55

标签: pytorch

考虑以下简单示例:

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv_0=torch.nn.Conv2d(3,32,3,stride=1,padding=0)
        self.blocks=torch.nn.ModuleList([
            torch.nn.Conv2d(3,32,3,stride=1,padding=0),
            torch.nn.Conv2d(32,64,3,stride=1,padding=0)])

        #the problematic part
        self.dict_block={"key_1": torch.nn.Conv2d(64,128,3,1,0),
                "key_2": torch.nn.Conv2d(56,1024,3,1,0)}

if __name__=="__main__":
    my_module=MyModule()
    print(my_module.parameters)

我在这里得到的输出是(注意self.dict_block中的参数丢失了)

<bound method Module.parameters of MyModule(
  (conv_0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (blocks): ModuleList(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  )
)>

这意味着如果我要优化self.dict_block的参数,则需要使用

my_optimiser.add_param_group({"params": params_from_self_dict})

在使用我的优化器之前。但是,我认为可能会有更直接的选择,即将self.dict_block的参数添加到my_module_object的参数中。如here所述,接近nn.Parameter(...)的地方是caret,但这需要输入为张量。

1 个答案:

答案 0 :(得分:1)

找到了答案。发布它以防有人遇到相同问题:

调查<torch_install>/torch/nn/modules/container.py似乎有一个类torch.nn.ModuleDict正是这样做的。因此,在我提出的示例中,解决方案将是:

self.dict_block=torch.nn.ModuleDict({"key_1": torch.nn.Conv2d(64,128,3,1,0),
            "key_2": torch.nn.Conv2d(56,1024,3,1,0)})