I'm implementing a Seq2Seq model in TensorFlow. My code works with the greedy decoder, but when I switched to BeamSearchDecoder to improve the results, I got this error:
Traceback (most recent call last):
  File "/Users/MichaelChen/Projects/CN-isA-Relation-Extraction/isa_seq2seq/predict.py", line 83, in <module>
    out_file='result/result_wc_4.out', checkpoint=checkpoint)
  File "/Users/MichaelChen/Projects/CN-isA-Relation-Extraction/isa_seq2seq/predict.py", line 48, in predict
    loader = tf.train.import_meta_graph(checkpoint + '.meta')
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1686, in import_meta_graph
    **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/meta_graph.py", line 504, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 283, in import_graph_def
    raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named GatherTree in defined operations.
The error occurs when I run the inference module to generate output:
with tf.Session(graph=loaded_graph) as sess:
    loader = tf.train.import_meta_graph(checkpoint + '.meta')
    loader.restore(sess, checkpoint)
    input_data = loaded_graph.get_tensor_by_name('inputs:0')
    logits = loaded_graph.get_tensor_by_name('predictions:0')
    src_seq_len = loaded_graph.get_tensor_by_name('source_sequence_length:0')
    tgt_seq_len = loaded_graph.get_tensor_by_name('target_sequence_length:0')
    for i in range(len(text)):
        if len(text[i].strip()) < 1:
            continue
        text_seq = src2seq_word(text[i], True)
        answer_logits = sess.run(logits, {input_data: [text_seq] * batch_size,
                                          tgt_seq_len: [len(text_seq)] * batch_size,
                                          src_seq_len: [len(text_seq)] * batch_size}
                                 )[0]
        pred_res = "".join([pp.id2c[i] for i in answer_logits if i != pad and i != eos])
The program fails on this line: loader = tf.train.import_meta_graph(checkpoint + '.meta')
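For context, GatherTree is the op that BeamSearchDecoder adds to the graph (via gather_tree in its finalize step) to backtrack through each beam's parent pointers and recover the full token sequences. The greedy decoder never creates it, which would be consistent with the greedy model loading fine. The pure-Python function below is only a sketch of what that op computes for a single batch entry, assuming step_ids and parent_ids are nested lists indexed [time][beam]. It is not TensorFlow's kernel and ignores per-sequence length limits.

```python
from typing import List


def gather_tree(step_ids: List[List[int]],
                parent_ids: List[List[int]],
                end_token: int) -> List[List[int]]:
    """Backtrack beam-search parent pointers for one batch entry (sketch).

    step_ids[t][k]: token chosen by beam k at step t
    parent_ids[t][k]: beam index at step t - 1 that beam k extended
    Returns beams[t][k], the full token sequence that ends in beam k.
    """
    num_steps = len(step_ids)
    beam_width = len(step_ids[0])
    beams = [[end_token] * beam_width for _ in range(num_steps)]
    for k in range(beam_width):
        # Start at the last step and follow the parent pointers backwards.
        parent = k
        for t in range(num_steps - 1, -1, -1):
            beams[t][k] = step_ids[t][parent]
            parent = parent_ids[t][parent]
        # Every token after the first end_token is forced to end_token.
        finished = False
        for t in range(num_steps):
            if finished:
                beams[t][k] = end_token
            elif beams[t][k] == end_token:
                finished = True
    return beams


step_ids = [[1, 2], [3, 4], [5, 6]]
parent_ids = [[0, 0], [1, 0], [0, 1]]
print(gather_tree(step_ids, parent_ids, end_token=99))  # [[2, 1], [3, 4], [5, 6]]
```

Without this backtracking step, the ids emitted at each time step would belong to whichever beam happened to sit at that index, not to one coherent hypothesis.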
I'm not sure whether I'm handling the decoder output correctly, so here is the relevant code:
# 5. Predicting decoder
# Share params with the training decoder
tiled_dec_start_state = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, beam_width)
tiled_src_seq_len = tf.contrib.seq2seq.tile_batch(src_seq_len, beam_width)
with tf.variable_scope('decode', reuse=True):
    batch_size_tensor = tf.constant(batch_size)
    beam_decoder_cell = get_decoder_cell(tiled_encoder_outputs, tiled_src_seq_len, 2 * num_units)
    beam_initial_state = beam_decoder_cell.zero_state(batch_size_tensor * beam_width, tf.float32)
    beam_initial_state = beam_initial_state.clone(cell_state=tiled_dec_start_state)
    start_tokens = tf.tile(tf.constant([c2id['<GO>']], dtype=tf.int32), [batch_size], name='start_tokens')
    predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=beam_decoder_cell,
        embedding=decoder_embeddings,
        start_tokens=start_tokens,
        end_token=c2id['<EOS>'],
        initial_state=beam_initial_state,
        beam_width=beam_width,
        output_layer=output_layer
    )
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=predicting_decoder, maximum_iterations=max_tgt_seq_len)
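The encoder outputs, state and lengths above go through tile_batch rather than tf.tile because BeamSearchDecoder expects the beam_width copies of each batch entry to sit next to each other. A list-based sketch of the two layouts, assuming a plain Python list stands in for the batch dimension of a tensor; both helper names are hypothetical:

```python
from typing import List, TypeVar

T = TypeVar("T")


def tile_batch_list(batch: List[T], multiplier: int) -> List[T]:
    """tile_batch-style layout: repeat each entry in a row, [a, b] -> [a, a, b, b]."""
    return [entry for entry in batch for _ in range(multiplier)]


def tile_list(batch: List[T], multiplier: int) -> List[T]:
    """tf.tile-style layout: repeat the whole batch, [a, b] -> [a, b, a, b]."""
    return batch * multiplier


src_lengths = [7, 4]  # illustrative source lengths for a batch of two sentences
beam_width = 3
tiled = tile_batch_list(src_lengths, beam_width)
print(tiled)                                # [7, 7, 7, 4, 4, 4]
print(tile_list(src_lengths, beam_width))   # [7, 4, 7, 4, 7, 4]

# Splitting the tile_batch layout into rows of beam_width keeps all beams
# of one sentence together, i.e. a reshape to [batch_size, beam_width].
rows = [tiled[i * beam_width:(i + 1) * beam_width] for i in range(len(src_lengths))]
print(rows)                                 # [[7, 7, 7], [4, 4, 4]]
```

With the tf.tile layout, the same reshape would mix beams from different sentences in one row.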
Handling the output:
training_decoder_output, predicting_decoder_output = seq2seq_model(params...)
training_logits = tf.identity(training_decoder_output.rnn_output, name='logits')
predicting_logits = tf.identity(predicting_decoder_output.predicted_ids[:,:,0], name='predictions')
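Since dynamic_decode returns batch-major outputs by default, predicted_ids from the beam search decoder should have shape [batch_size, max_time, beam_width], and the beams are kept ordered by score, so index 0 on the last axis is the highest-scoring hypothesis. A nested-list sketch of what the [:,:,0] slice selects (the ids are made up for illustration, and top_beam is a hypothetical helper):

```python
from typing import List


def top_beam(predicted_ids: List[List[List[int]]]) -> List[List[int]]:
    """Mimic predicted_ids[:, :, 0] for a nested list shaped [batch][time][beam]."""
    return [[step[0] for step in sequence] for sequence in predicted_ids]


# batch_size = 1, max_time = 3, beam_width = 2 (illustrative values)
predicted_ids = [[[11, 12], [21, 22], [31, 32]]]
print(top_beam(predicted_ids))  # [[11, 21, 31]]
```

The trailing [0] in the inference loop's sess.run call then takes the first row of this [batch_size, max_time] result.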
Also, I found the following in the nmt model at https://github.com/tensorflow/nmt/blob/77e6c55052ba31a8d733c94bb820d091c8156d35/nmt/model.py (line 391):
if beam_width > 0:
    logits = tf.no_op()
    sample_id = outputs.predicted_ids
else:
    logits = outputs.rnn_output
    sample_id = outputs.sample_id
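The nmt code branches this way because the two decoders return different output structures: BasicDecoderOutput has rnn_output and sample_id, while the FinalBeamSearchDecoderOutput that dynamic_decode returns for beam search has predicted_ids and beam_search_decoder_output but no rnn_output, so there are no logits to expose. The namedtuples below are hypothetical stand-ins that only copy those field names, so the selection logic can be shown without TensorFlow; None plays the role of tf.no_op().

```python
from collections import namedtuple

# Hypothetical stand-ins that reuse the field names of the TF 1.x output tuples.
BasicDecoderOutput = namedtuple("BasicDecoderOutput", ["rnn_output", "sample_id"])
FinalBeamSearchDecoderOutput = namedtuple(
    "FinalBeamSearchDecoderOutput", ["predicted_ids", "beam_search_decoder_output"])


def select_outputs(outputs, beam_width):
    """Return (logits, sample_id) following the branch in nmt's model.py."""
    if beam_width > 0:
        # Beam search output has no rnn_output field, only backtracked ids.
        return None, outputs.predicted_ids
    return outputs.rnn_output, outputs.sample_id


greedy = BasicDecoderOutput(rnn_output=[[0.1, 0.9]], sample_id=[1])
beam = FinalBeamSearchDecoderOutput(predicted_ids=[[[1, 2]]], beam_search_decoder_output=None)
print(select_outputs(greedy, beam_width=0))  # ([[0.1, 0.9]], [1])
print(select_outputs(beam, beam_width=2))    # (None, [[[1, 2]]])
```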
Is this related to my error?
Thanks in advance!