在使用seq2seq架构的NMT中,在推理期间,我们需要在训练阶段训练的嵌入变量作为GreedyEmbeddingHelper或BeamSearchDecoder的输入。
问题是,在使用Estimator API进行训练和推断的背景下,我们如何提取这个训练好的嵌入变量用于预测?
答案 0 :(得分:0)
我找到了一个基于以下stackoverflow answer的解决方案。对于预测阶段,您可以使用tf.contrib.framework.load_variable从训练有素且保存的Tensorflow模型中检索嵌入变量,如下所示:
glucose_mass
所以在我的情况下,我运行的代码来自包含已保存模型的同一文件夹,我的变量名称是'embed / embedding'。请注意,这仅适用于通过张量流模型训练的嵌入。否则,请参阅上面链接的答案。
要使用估算器API查找变量名称,可以使用方法get_variable_names()来获取图表中保存的所有变量名称的列表。