使用OpenNMT转移学习

时间:2019-05-02 14:04:23

标签: python pytorch transformer transfer-learning opennmt

我正在用OpenNMT-py训练MIDI音乐文件上的转换器模型,但是效果很差,因为我只能访问与我要研究的风格有关的小型数据集。为了帮助模型学习有用的东西,我想使用其他音乐风格更大的数据集进行预训练,然后使用较小的数据集对结果进行微调。

我正在考虑在预训练后冻结变压器的编码器侧,并让解码器部分自由进行微调。用OpenNMT-py怎么办?

1 个答案:

答案 0 :(得分:1)

请更详细地说明您的问题,并显示一些代码,这些代码将帮助您从SO社区获得富有成效的答复。

如果我在您的位置并想冻结神经网络组件,我会简单地做:

for name, param in self.encoder.named_parameters():
    param.requires_grad = False

在这里,我假设您有一个如下所示的NN模块。

class Net(nn.Module):
    def __init__(self, params):
        super(Net, self).__init__()

        self.encoder = TransformerEncoder(num_layers,
                                        d_model, 
                                        heads, 
                                        d_ff, 
                                        dropout, 
                                        embeddings,
                                        max_relative_positions)

    def foward(self):
        # write your code