我想把 PixelCNN 改成适用于一维数据(即根据一个向量生成另一个向量)。
我在 PixelCNN 类的 `__init__` 中添加了 e_dim 和 z_dim 两个参数,也就是说,我想用这个模型把 e_dim 维的输入映射到 z_dim 维的输出。
该模型的代码为:
class PixelCNN(nn.Module):
    """Stack of masked 2-D convolutions followed by a 1x1 output convolution.

    The network is ``no_layers`` stages of ``MaskedCNN -> BatchNorm2d -> ReLU``.
    The first stage uses mask type ``'A'`` and maps ``e_dim`` input channels to
    ``channels``; every later stage uses mask type ``'B'`` and keeps
    ``channels`` channels.  A final unmasked 1x1 ``nn.Conv2d`` maps
    ``channels`` to ``z_dim`` output channels.

    Sub-modules are registered as ``Conv2d_<i>``, ``BatchNorm2d_<i>`` and
    ``ReLU_<i>`` (``i`` starting at 1) plus ``out``, so with the default
    ``no_layers=8`` the attribute names and registration order are the same
    as in the previous hand-unrolled version.

    Args:
        no_layers: Number of masked-convolution stages; must be at least 1.
        kernel: Kernel size of every masked convolution (padding is
            ``kernel // 2``).
        channels: Number of feature channels between stages.
        device: Stored on the instance as ``self.device``; not used by this
            class itself.
        e_dim: Number of input channels of the first masked convolution.
        z_dim: Number of output channels of the final 1x1 convolution.

    Raises:
        ValueError: If ``no_layers`` is smaller than 1.
    """

    # NOTE(review): the defaults ``xn`` / ``yn`` are not defined in this
    # block; they must exist at module level before the class is created.
    def __init__(self, no_layers=8, kernel=7, channels=64, device=None,
                 e_dim=xn, z_dim=yn):
        super(PixelCNN, self).__init__()
        if no_layers < 1:
            # Without the type-'A' first stage nothing maps e_dim -> channels.
            raise ValueError("no_layers must be >= 1, got %r" % (no_layers,))

        self.no_layers = no_layers
        self.kernel = kernel
        self.channels = channels
        self.layers = {}  # kept for backward compatibility; unused here
        self.device = device

        padding = kernel // 2
        for idx in range(1, no_layers + 1):
            # Only the first stage sees the raw input: mask 'A', e_dim inputs.
            if idx == 1:
                mask_type, in_channels = 'A', e_dim
            else:
                mask_type, in_channels = 'B', channels
            setattr(self, 'Conv2d_%d' % idx,
                    MaskedCNN(mask_type, in_channels, channels, kernel, 1,
                              padding, bias=False))
            setattr(self, 'BatchNorm2d_%d' % idx, nn.BatchNorm2d(channels))
            setattr(self, 'ReLU_%d' % idx, nn.ReLU(True))

        self.out = nn.Conv2d(channels, z_dim, 1)

    def forward(self, x):
        """Run ``x`` through every conv/batch-norm/ReLU stage, then ``out``."""
        for idx in range(1, self.no_layers + 1):
            x = getattr(self, 'Conv2d_%d' % idx)(x)
            x = getattr(self, 'BatchNorm2d_%d' % idx)(x)
            x = getattr(self, 'ReLU_%d' % idx)(x)
        return self.out(x)
其中MaskedCNN定义为:
class MaskedCNN(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is multiplied by a fixed binary mask.

    Masked convolution as described by A. Oord et al.; adapted from
    https://github.com/jzbontar/pixelcnn-pytorch

    The mask has the same shape as the convolution weight and is registered
    as the buffer ``mask``.  It starts as all ones, then:

    * every kernel row below the centre row is zeroed;
    * in the centre row, columns from the centre onward are zeroed for
      mask type ``'A'`` (centre included), and columns strictly after the
      centre are zeroed for mask type ``'B'`` (centre kept).

    Args:
        mask_type: ``'A'`` or ``'B'``.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d.__init__``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        self.mask_type = mask_type
        assert mask_type in ['A', 'B'], "Unknown Mask Type"
        super(MaskedCNN, self).__init__(*args, **kwargs)

        # Buffer shaped like the weight, initialised to all ones.
        self.register_buffer('mask', self.weight.data.clone())
        self.mask.fill_(1)

        kernel_h, kernel_w = self.weight.size()[2:]
        centre_row = kernel_h // 2
        # Type 'A' blocks the centre tap as well; type 'B' keeps it.
        if mask_type == 'A':
            first_blocked_col = kernel_w // 2
        else:
            first_blocked_col = kernel_w // 2 + 1

        self.mask[:, :, centre_row, first_blocked_col:] = 0
        self.mask[:, :, centre_row + 1:, :] = 0

    def forward(self, x):
        """Zero the masked weights in place, then apply the convolution."""
        self.weight.data.mul_(self.mask)
        return super(MaskedCNN, self).forward(x)
有人做过类似的事情吗?我考虑过把 MaskedCNN 改成可用于一维数据的版本,但不确定掩码(mask)应该如何相应修改。
谢谢:)