我正在使用tensorflow的seq2seq.py
中的函数构建一个seq2seq模型,它们具有如下函数:
embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
num_encoder_symbols, num_decoder_symbols,
embedding_size, output_projection=None,
feed_previous=False, dtype=dtypes.float32,
scope=None)
但是,似乎这个函数没有将预先训练好的嵌入作为输入,有没有什么方法可以将预训练的字嵌入作为此函数的输入?
答案 0 :(得分:1)
您只需交出参数即可。读入嵌入(确保词汇表ID匹配)。然后,一旦初始化了所有变量,找到嵌入张量(迭代通过tf.all_variables来查找名称)。然后使用tf.assign用嵌入来覆盖随机初始化的嵌入。