我目前正在尝试扩展基于FairSeq / PyTorch的{{3}}。在训练期间,我需要训练两种编码器:一种使用目标样本,另一种使用源样本。
所以当前的前进功能如下:
def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
基于这个a model,我想要这样的东西:
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
concat = some_concatination_func(encoder_out, autoencoder_out)
decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
return decoder_out
有没有办法做到这一点?
编辑: 这些是我的约束,因为我需要扩展 FairseqEncoderDecoderModel :
@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
编辑2:
可以通过实现自己的条件来更改传递给Fairseq中正向函数的参数,例如,参见this idea,其中sample['net_input']
被传递给模型的__call__
函数,该函数调用forward
方法。
答案 0 :(得分:3)
首先,您应该始终使用并定义forward
,而不是在torch.nn.Module
实例上调用的其他方法。
因为它是PyTorch(trsvchn)定义的评估方法,所以绝对不会如see here所示eval()
超载。该方法允许模型内部的图层置于评估模式(例如,对Dropout
或BatchNorm
的推理模式等层的特定更改)。
此外,您应该使用__call__
魔术方法来调用它。为什么?因为钩子和其他PyTorch特定的东西都以这种方式正确注册。
第二,不要使用@Anant Mittal 建议的某些外部mode
字符串变量。这就是PyTorch中的train
变量的用途,标准区分模型是处于eval
模式还是train
模式。
话虽如此,您最好这样做:
import torch
class Network(torch.nn.Module):
def __init__(self):
super().__init__()
...
# You could split it into two functions but both should be called by forward
def forward(
self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
if self.train:
return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
concat = some_concatination_func(encoder_out, autoencoder_out)
return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
您可以(并且可以说应该)将以上内容分为两个单独的方法,但这并不算太糟糕,因为该函数相当简短且易于理解。如果可能的话,请坚持使用PyTorch的处理方式,而不是某些临时解决方案。不,反向传播不会有问题,为什么会有反向传播?
答案 1 :(得分:-1)
默认情况下,调用model()
会调用forward
方法,该方法在您的情况下会向前发展,因此您只需要在模型类中为测试/评估路径定义新方法,如下所示:
代码:
class FooBar(nn.Module):
"""Dummy Net for testing/debugging.
"""
def __init__(self):
super().__init__()
...
def forward(self, x):
# here will be train forward
...
def evaltest(self, x):
# here will be eval/test forward
...
示例:
model = FooBar() # initialize model
# train time
pred = model(x) # calls forward() method under the hood
# test/eval time
test_pred = model.evaltest(x)
评论: 我建议您将这两个正向路径拆分为2个单独的方法,因为它易于调试,并且可以避免在向后传播时出现一些可能的问题。