在tensorflow的seq2seq函数中使用预先训练的字嵌入

时间:2016-08-29 17:23:00

标签: python machine-learning tensorflow recurrent-neural-network lstm

我正在使用tensorflow的seq2seq.py中的函数构建一个seq2seq模型,它们具有如下函数:

embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
                          num_encoder_symbols, num_decoder_symbols,
                          embedding_size, output_projection=None,
                          feed_previous=False, dtype=dtypes.float32,
                          scope=None)

但是,似乎这个函数没有将预先训练好的嵌入作为输入,有没有什么方法可以将预训练的字嵌入作为此函数的输入?

1 个答案:

答案 0 :(得分:1)

您只需交出参数即可。读入嵌入(确保词汇表ID匹配)。然后,一旦初始化了所有变量,找到嵌入张量(迭代通过tf.all_variables来查找名称)。然后使用tf.assign用嵌入来覆盖随机初始化的嵌入。