PyTorch逐元素过滤层

时间:2018-08-23 07:41:55

标签: python python-3.x neural-network pytorch

嗨,我想添加逐元素乘法层,以将输入复制到多通道,如图所示。 (因此,输入大小M x N和乘法滤波器大小M x N相同),如图所示

我想添加自定义初始化值以进行过滤,还希望他们在训练时获得渐变。但是,我在PyTorch中找不到按元素分类的滤镜层。我能做到吗?还是在PyTorch中是不可能的?

1 个答案:

答案 0 :(得分:2)

在pytorch中,您始终可以通过将其设为nn.Module的子类来实现自己的图层。您还可以使用nn.Parameter在图层中设置可训练的参数。
这种层的可能实现看起来像

import torch
from torch import nn

class TrainableEltwiseLayer(nn.Module)
  def __init__(self, n, h, w):
    super(TrainableEltwiseLayer, self).__init__()
    self.weights = nn.Parameter(torch.tensor(1, n, h, w))  # define the trainable parameter

  def forward(self, x):
    # assuming x is of size b-1-h-w
    return x * self.weights  # element-wise multiplication

您仍然需要担心初始化权重。研究nn.init的权重初始化方法。通常是在训练之前和加载任何存储的模型之前初始化所有网络的权重(因此,部分训练的模型可以覆盖随机初始化)。像

model = mymodel(*args, **kwargs)  # instantiate a model
for m in model.modules():
  if isinstance(m, nn.Conv2d):
     nn.init.normal_(m.weight)  # init for conv layers
  if isinstance(m, TrainableEltwiseLayer):
     nn.init.constant_(m.weights, 1)  # init your weights here...