Tensorflow:在训练期间使用BasicDecoder,在推理期间使用BeamSearchDecoder

时间:2020-06-10 04:03:11

标签: tensorflow

我正在尝试将旧的TF1项目转换为TF2,并遇到名称范围问题。这是权限代码:

if mode != tf.estimator.ModeKeys.PREDICT:
   decoder = tfa.seq2seq.BasicDecoder(
      cell=decoder_cell,
      sampler=sampler,
      output_layer=output_layer
   )
else:
   decoder = tfa.seq2seq.BeamSearchDecoder(
      cell=decoder_cell,
      beam_width=beam_width,
      length_penalty_weight=length_penalty_weight,
      output_layer=output_layer
   )

问题是BasicDecoder不在其“ step”方法中添加名称范围,该范围随后由tfa.seq2seq.dynamic_decode在代码中调用,而BeamSearchDecoder添加了BeamSearchDecoderStep/ ,导致训练时间与推断时间之间的变量名称不同。

因此,当我尝试加载经过训练的模型进行预测时,最终会遇到“在检查点中未找到Key BeamSearchDecoderStep / attention_wrapper / BahdanauAttention / kernel”之类的错误(请注意,在检查点中确实存在tention_wrapper / BahdanauAttention / kernel)

我想到的第一个解决方案是用相同的代码替换step方法,但是没有添加name_scope的行,但这似乎是令人作呕的黑客。解决我的问题的正确方法是什么?

编辑:如果发生任何更改,我将使用Estimator API创建模型。

0 个答案:

没有答案