PixelCNN:如何使其在一维数据上工作

时间:2019-06-25 11:58:51

标签: python deep-learning pytorch

我想让 PixelCNN 适用于一维数据(即根据一个向量生成另一个向量)。

我在 PixelCNN 类的 `__init__` 中添加了 e_dim 和 z_dim 两个参数——也就是说,我想用这个模型实现从 e_dim 到 z_dim 的映射。

该模型的代码为:

   class PixelCNN(nn.Module):
       """Stack of masked 2-D convolutions followed by a 1x1 output convolution.

       The first layer is a type-'A' ``MaskedCNN`` that maps ``e_dim`` input
       channels to ``channels`` feature maps; every later layer is a type-'B'
       ``MaskedCNN`` that keeps ``channels`` feature maps. Each masked
       convolution is followed by ``BatchNorm2d`` and an in-place ``ReLU``.
       A final ``nn.Conv2d`` with kernel size 1 maps ``channels`` to ``z_dim``
       output channels.

       Sub-modules are registered as ``Conv2d_<i>``, ``BatchNorm2d_<i>`` and
       ``ReLU_<i>`` (``i`` starting at 1), so the default 8-layer model exposes
       the same attribute names and ``state_dict`` keys as the hand-unrolled
       version.

       Args:
           no_layers: Number of masked convolution layers (at least 1).
           kernel: Kernel size of every masked convolution; padding is
               ``kernel // 2``.
           channels: Number of feature maps in the hidden layers.
           device: Stored on the instance as-is; not used by this class.
           e_dim: Number of input channels. Required.
           z_dim: Number of output channels. Required.

       Raises:
           ValueError: If ``e_dim`` or ``z_dim`` is missing, or if
               ``no_layers`` is smaller than 1.
       """

       def __init__(self, no_layers=8, kernel=7, channels=64, device=None,
                    e_dim=None, z_dim=None):
           # The original defaults were the undefined placeholder names
           # ``xn`` / ``yn``; both sizes must be supplied by the caller.
           if e_dim is None or z_dim is None:
               raise ValueError("e_dim and z_dim must both be specified")
           if no_layers < 1:
               raise ValueError("no_layers must be at least 1")

           super(PixelCNN, self).__init__()
           self.no_layers = no_layers
           self.kernel = kernel
           self.channels = channels
           self.layers = {}  # kept for compatibility; nothing here reads it
           self.device = device

           for idx in range(1, no_layers + 1):
               # Only the first layer uses mask 'A' and sees the raw input.
               if idx == 1:
                   mask_type, in_channels = 'A', e_dim
               else:
                   mask_type, in_channels = 'B', channels

               # setattr on an nn.Module registers the value as a sub-module
               # under the given name, exactly like ``self.Conv2d_1 = ...``.
               setattr(self, 'Conv2d_%d' % idx,
                       MaskedCNN(mask_type, in_channels, channels, kernel, 1,
                                 kernel // 2, bias=False))
               setattr(self, 'BatchNorm2d_%d' % idx, nn.BatchNorm2d(channels))
               setattr(self, 'ReLU_%d' % idx, nn.ReLU(True))

           self.out = nn.Conv2d(channels, z_dim, 1)

       def forward(self, x):
           """Run ``x`` through every conv/batch-norm/ReLU stage, then ``out``.

           Args:
               x: Input tensor for the first masked ``Conv2d`` layer
                   (presumably shaped ``(batch, e_dim, H, W)`` -- TODO confirm
                   against the caller).

           Returns:
               The output of the final 1x1 convolution.
           """
           for idx in range(1, self.no_layers + 1):
               x = getattr(self, 'Conv2d_%d' % idx)(x)
               x = getattr(self, 'BatchNorm2d_%d' % idx)(x)
               x = getattr(self, 'ReLU_%d' % idx)(x)

           return self.out(x)

其中MaskedCNN定义为:

   class MaskedCNN(nn.Conv2d):
       """Masked 2-D convolution as described by A. Oord et al.

       Taken from https://github.com/jzbontar/pixelcnn-pytorch

       A 0/1 buffer named ``mask`` with the same shape as ``weight`` is
       created at construction time, and the weights are multiplied by it
       in place on every forward pass.

       In the centre kernel row, mask type 'A' zeroes the centre column and
       everything to its right, while type 'B' keeps the centre column and
       zeroes only what lies to its right. Both types zero every row below
       the centre row.

       Args:
           mask_type: Either 'A' or 'B'.
           *args: Positional arguments forwarded to ``nn.Conv2d``.
           **kwargs: Keyword arguments forwarded to ``nn.Conv2d``.

       Raises:
           AssertionError: If ``mask_type`` is neither 'A' nor 'B'.
       """

       def __init__(self, mask_type, *args, **kwargs):
           # Validate before any state is created.
           assert mask_type in ['A', 'B'], "Unknown Mask Type"
           super(MaskedCNN, self).__init__(*args, **kwargs)
           self.mask_type = mask_type

           # A buffer follows the module across devices and is saved in the
           # state_dict, but is not a trainable parameter.
           self.register_buffer('mask', self.weight.data.clone())
           self.mask.fill_(1)

           _, _, height, width = self.weight.size()
           # Type 'A' also hides the centre position; type 'B' keeps it.
           first_hidden_col = width // 2 if mask_type == 'A' else width // 2 + 1
           self.mask[:, :, height // 2, first_hidden_col:] = 0
           # Every row below the centre row is hidden for both mask types.
           self.mask[:, :, height // 2 + 1:, :] = 0

       def forward(self, x):
           """Zero the masked weights in place, then apply the convolution."""
           self.weight.data *= self.mask
           return super(MaskedCNN, self).forward(x)

有人做过类似的事情吗?我考虑过把 MaskedCNN 改成适用于一维数据的版本,但不确定掩码(mask)应该如何相应修改。

谢谢:)

0 个答案:

没有答案