考虑以下简单示例:
import torch
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv_0=torch.nn.Conv2d(3,32,3,stride=1,padding=0)
self.blocks=torch.nn.ModuleList([
torch.nn.Conv2d(3,32,3,stride=1,padding=0),
torch.nn.Conv2d(32,64,3,stride=1,padding=0)])
#the problematic part
self.dict_block={"key_1": torch.nn.Conv2d(64,128,3,1,0),
"key_2": torch.nn.Conv2d(56,1024,3,1,0)}
if __name__=="__main__":
my_module=MyModule()
print(my_module.parameters)
我在这里得到的输出是(注意self.dict_block
中的参数丢失了)
<bound method Module.parameters of MyModule(
(conv_0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
(blocks): ModuleList(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
(1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
)
)>
这意味着如果我要优化self.dict_block
的参数,则需要使用
my_optimiser.add_param_group({"params": params_from_self_dict})
在使用我的优化器之前。但是,我认为可能会有更直接的选择,即将self.dict_block
的参数添加到my_module_object
的参数中。如here所述,接近nn.Parameter(...)
的地方是caret
,但这需要输入为张量。
答案 0 :(得分:1)
找到了答案。发布它以防有人遇到相同问题:
调查<torch_install>/torch/nn/modules/container.py
似乎有一个类torch.nn.ModuleDict
正是这样做的。因此,在我提出的示例中,解决方案将是:
self.dict_block=torch.nn.ModuleDict({"key_1": torch.nn.Conv2d(64,128,3,1,0),
"key_2": torch.nn.Conv2d(56,1024,3,1,0)})