我正在尝试将旧的TF1项目转换为TF2,并遇到名称范围问题。这是权限代码:
if mode != tf.estimator.ModeKeys.PREDICT:
decoder = tfa.seq2seq.BasicDecoder(
cell=decoder_cell,
sampler=sampler,
output_layer=output_layer
)
else:
decoder = tfa.seq2seq.BeamSearchDecoder(
cell=decoder_cell,
beam_width=beam_width,
length_penalty_weight=length_penalty_weight,
output_layer=output_layer
)
问题是BasicDecoder
不在其“ step”方法中添加名称范围,该范围随后由tfa.seq2seq.dynamic_decode
在代码中调用,而BeamSearchDecoder
添加了BeamSearchDecoderStep/
,导致训练时间与推断时间之间的变量名称不同。
因此,当我尝试加载经过训练的模型进行预测时,最终会遇到“在检查点中未找到Key BeamSearchDecoderStep / attention_wrapper / BahdanauAttention / kernel”之类的错误(请注意,在检查点中确实存在tention_wrapper / BahdanauAttention / kernel)
我想到的第一个解决方案是用相同的代码替换step
方法,但是没有添加name_scope的行,但这似乎是令人作呕的黑客。解决我的问题的正确方法是什么?
编辑:如果发生任何更改,我将使用Estimator API创建模型。