PyTorch - are my weights shared in this network?

Time: 2018-12-06 22:50:11

Tags: pytorch

I have the following network, on which I am trying out triplet loss:

First, I have a custom convolution block, class ConvBlock(nn.Module):

def __init__(self, ngpu, input_c, output_c, mode=0):
    super(ConvBlock, self).__init__()
    self.ngpu = ngpu
    self.input_c = input_c 
    self.output_c = output_c
    self.mode = mode

    self.b1 = nn.Sequential(
            nn.Conv2d(input_c, output_c, 3, stride=1, padding=1),
            #nn.BatchNorm2d(output_c),
            nn.PReLU(),
        )
    self.b2 = nn.Sequential(
            nn.Conv2d(output_c, output_c, 3, stride=1, padding=1),
            #nn.BatchNorm2d(output_c),
            nn.PReLU(),
        )
    self.pool = nn.Sequential(
            nn.MaxPool2d(2, 2),
        )

def forward(self, input):

    batch_size = input.size(0)
    if self.mode == 0:
        b1 = self.b1(input)
        hidden = self.pool(b1)
        return hidden, b1
    elif self.mode == 1:
        b1 = self.b1(input)
        b2 = self.b2(b1)
        hidden = self.pool(b2)
        return hidden, b2
    elif self.mode == 2:
        b1 = self.b1(input)
        hidden = self.b2(b1)
        return hidden
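
For reference, a single ConvBlock can be exercised on a dummy batch like this (the sizes are only illustrative, not my real data, and this assumes the class above is defined as shown):

import torch

block = ConvBlock(ngpu=1, input_c=3, output_c=64, mode=0)
x = torch.randn(4, 3, 32, 32)   # dummy batch of 4 RGB images
hidden, skip = block(x)
print(hidden.shape)             # torch.Size([4, 64, 16, 16]) -- after the 2x2 max-pool
print(skip.shape)               # torch.Size([4, 64, 32, 32]) -- pre-pool activation returned for the skip path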

I now have an encoder module, class _Encoder(nn.Module):

def __init__(self, ngpu,nc,nef,out_size,nz):
    super(_Encoder, self).__init__()
    self.ngpu = ngpu
    self.nc = nc 
    self.nef = nef
    self.out_size = out_size
    self.nz = nz

    self.c1 = ConvBlock(self.ngpu, nc, nef, 0)       # 3 - 64
    self.c2 = ConvBlock(self.ngpu, nef, nef*2, 0)    # 64-128
    self.c3 = ConvBlock(self.ngpu, nef*2, nef*4, 1)  # 128-256
    self.c4 = ConvBlock(self.ngpu, nef*4, nef*8, 1)  # 256 -512
    self.c5 = ConvBlock(self.ngpu, nef*8, nef*8, 2)  # 512-512

    # 8 because the depth went from 32 to 32*8
    # (out_size // 2 keeps the in_features argument an integer)
    self.mean = nn.Linear(nef * 8 * out_size * (out_size // 2), nz)
    self.logvar = nn.Linear(nef * 8 * out_size * (out_size // 2), nz)

#for reparametrization trick 
def sampler(self, mean, logvar):  
    std = logvar.mul(0.5).exp_()
    if args.cuda:
        eps = torch.cuda.FloatTensor(std.size()).normal_()
    else:
        eps = torch.FloatTensor(std.size()).normal_()
    eps = Variable(eps)
    return eps.mul(std).add_(mean)

def forward(self, input):
    batch_size = input.size(0)
    if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        c1_out, c1_x = nn.parallel.data_parallel(self.c1, input, range(self.ngpu))
        c2_out, c2_x = nn.parallel.data_parallel(self.c2, c1_out, range(self.ngpu))
        c3_out, c3_x = nn.parallel.data_parallel(self.c3, c2_out, range(self.ngpu))
        c4_out, c4_x = nn.parallel.data_parallel(self.c4, c3_out, range(self.ngpu))
        hidden = nn.parallel.data_parallel(self.c5, c4_out, range(self.ngpu))

        # hidden = nn.parallel.data_parallel(self.encoder, input, range(self.ngpu))
        hidden = hidden.view(batch_size, -1)
        mean = nn.parallel.data_parallel(self.mean, hidden, range(self.ngpu))
        logvar = nn.parallel.data_parallel(self.logvar, hidden, range(self.ngpu))
    else:
        c1_out, c1_x = self.c1(input)
        c2_out, c2_x = self.c2(c1_out)
        c3_out, c3_x = self.c3(c2_out)
        c4_out, c4_x = self.c4(c3_out)
        hidden = self.c5(c4_out)

        # hidden = self.encoder(input)
        hidden = hidden.view(batch_size, -1)
        mean, logvar = self.mean(hidden), self.logvar(hidden)

    latent_z = self.sampler(mean, logvar)
    if ADD_SKIP_CONNECTION:
        return latent_z,mean,logvar,{"c1_x":c1_x, "c2_x":c2_x, "c3_x":c3_x, "c4_x":c4_x}
    else:
        return latent_z,mean,logvar,{"c1_x":None, "c2_x":None, "c3_x":None, "c4_x":None}
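
For context, the sampler above is just the usual reparametrization trick, z = mean + eps * exp(0.5 * logvar) with eps ~ N(0, I). On PyTorch >= 0.4 the same math could be written device-agnostically (a rough sketch, not the code I am actually running):

import torch

def sample_latent(mean, logvar):
    # std = exp(0.5 * logvar); eps is drawn on the same device/dtype as std
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mean + eps * std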

I initialize the encoder as a single object:

encoder = _Encoder(ngpu,nc,nef,out_size,nz)
encoder = encoder.cuda()

Then I apply some functions to the three inputs:

    latent_x,mean_x,logvar_x,skip_x = self.encoder(x)
    latent_y,mean_y,logvar_y,skip_y = self.encoder(y)
    latent_z,mean_z,logvar_z,skip_z = self.encoder(z)
    dist_a = F.pairwise_distance(mean_x, mean_y, 2)
    dist_b = F.pairwise_distance(mean_x, mean_z, 2)
    loss_triplet = triplet_loss(dist_a, dist_b, target)

    optimizer.zero_grad()
    loss_triplet.backward()
    optimizer.step()

I am starting to wonder whether the weights are actually shared across the three encoder calls. Can you help me check whether they are?
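
One check I thought of is to list the encoder's parameter tensors: if the weights are shared, there should be exactly one set of conv/linear weights in total, not one copy per input (a rough sketch, using the encoder object from above):

for name, param in encoder.named_parameters():
    print(name, tuple(param.shape))

print("total trainable parameters:",
      sum(p.numel() for p in encoder.parameters()))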

0 Answers:

There are no answers.