PyTorch:用于培训和测试/验证的不同转发方法

时间:2019-11-01 06:48:54

标签: python-3.x neural-network pytorch transformer seq2seq

我目前正在尝试扩展基于FairSeq / PyTorch的{​​{3}}。在训练期间,我需要训练两种编码器:一种使用目标样本,另一种使用源样本。

所以当前的前进功能如下:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

基于这个a model,我想要这样的东西:

def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out

有没有办法做到这一点?

编辑: 这些是我的约束,因为我需要扩展 FairseqEncoderDecoderModel

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder) 

编辑2: 可以通过实现自己的条件来更改传递给Fairseq中正向函数的参数,例如,参见this idea,其中sample['net_input']被传递给模型的__call__函数,该函数调用forward方法。

2 个答案:

答案 0 :(得分:3)

首先,您应该始终使用并定义forward ,而不是在torch.nn.Module实例上调用的其他方法。

因为它是PyTorch(trsvchn)定义的评估方法,所以绝对不会如see here所示eval()超载。该方法允许模型内部的图层置于评估模式(例如,对DropoutBatchNorm的推理模式等层的特定更改)。

此外,您应该使用__call__魔术方法来调用它。为什么?因为钩子和其他PyTorch特定的东西都以这种方式正确注册。

第二,不要使用@Anant Mittal 建议的某些外部mode字符串变量。这就是PyTorch中的train变量的用途,标准区分模型是处于eval模式还是train模式。

话虽如此,您最好这样做:

import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

您可以(并且可以说应该)将以上内容分为两个单独的方法,但这并不算太糟糕,因为该函数相当简短且易于理解。如果可能的话,请坚持使用PyTorch的处理方式,而不是某些临时解决方案。不,反向传播不会有问题,为什么会有反向传播?

答案 1 :(得分:-1)

默认情况下,调用model()会调用forward方法,该方法在您的情况下会向前发展,因此您只需要在模型类中为测试/评估路径定义新方法,如下所示:

代码:

class FooBar(nn.Module):
    """Dummy Net for testing/debugging.
    """

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        # here will be train forward
        ...

    def evaltest(self, x):
        # here will be eval/test forward
        ...

示例:

model = FooBar()  # initialize model 

# train time
pred = model(x)   # calls forward() method under the hood

# test/eval time
test_pred = model.evaltest(x)

评论: 我建议您将这两个正向路径拆分为2个单独的方法,因为它易于调试,并且可以避免在向后传播时出现一些可能的问题。