No op named GatherTree when using BeamSearchDecoder

Time: 2017-08-23 04:42:56

Tags: tensorflow nlp deep-learning sequence-to-sequence

I'm implementing a Seq2Seq model in TensorFlow. My code works with a greedy decoder, but when I switched to BeamSearchDecoder to improve performance, I got this error:

Traceback (most recent call last):
  File "/Users/MichaelChen/Projects/CN-isA-Relation-Extraction/isa_seq2seq/predict.py", line 83, in <module>
    out_file='result/result_wc_4.out', checkpoint=checkpoint)
  File "/Users/MichaelChen/Projects/CN-isA-Relation-Extraction/isa_seq2seq/predict.py", line 48, in predict
    loader = tf.train.import_meta_graph(checkpoint + '.meta')
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1686, in import_meta_graph
    **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/meta_graph.py", line 504, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 283, in import_graph_def
    raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named GatherTree in defined operations.

The error occurs when I run the inference module to generate output:

with tf.Session(graph=loaded_graph) as sess:
    loader = tf.train.import_meta_graph(checkpoint + '.meta')
    loader.restore(sess, checkpoint)

    input_data = loaded_graph.get_tensor_by_name('inputs:0')
    logits = loaded_graph.get_tensor_by_name('predictions:0')
    src_seq_len = loaded_graph.get_tensor_by_name('source_sequence_length:0')
    tgt_seq_len = loaded_graph.get_tensor_by_name('target_sequence_length:0')

    for i in range(len(text)):
        if len(text[i].strip()) < 1:
            continue
        text_seq = src2seq_word(text[i], True)
        answer_logits = sess.run(logits, {input_data: [text_seq] * batch_size,
                                          tgt_seq_len: [len(text_seq)] * batch_size,
                                          src_seq_len: [len(text_seq)] * batch_size}
                                 )[0]
        # Map ids back to characters, dropping padding and end-of-sequence ids.
        pred_res = "".join([pp.id2c[idx] for idx in answer_logits if idx != pad and idx != eos])
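The last line of the loop turns the first row of predicted ids back into text, skipping padding and end-of-sequence ids. A minimal pure-Python sketch of that step, assuming a toy vocabulary and pad/eos ids invented here for illustration (ID2C, PAD_ID and EOS_ID are hypothetical stand-ins for pp.id2c, pad and eos, which the question does not show):

```python
# Hypothetical toy vocabulary standing in for pp.id2c; the ids are made up.
ID2C = {0: "<PAD>", 1: "<GO>", 2: "<EOS>", 3: "c", 4: "a", 5: "t"}
PAD_ID = 0  # hypothetical stand-in for `pad`
EOS_ID = 2  # hypothetical stand-in for `eos`


def ids_to_text(pred_ids, id2c=ID2C, pad=PAD_ID, eos=EOS_ID):
    """Join the characters of every id that is neither padding nor end-of-sequence."""
    return "".join(id2c[idx] for idx in pred_ids if idx != pad and idx != eos)


# The feed dict repeats the same sentence batch_size times,
# so only row 0 of the result is actually used.
batch_of_ids = [[3, 4, 5, 2, 0], [3, 4, 5, 2, 0]]
print(ids_to_text(batch_of_ids[0]))  # cat
```

Note that, like the original comprehension, this filter drops pad and eos ids wherever they appear rather than stopping at the first eos.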

The program fails at the line loader = tf.train.import_meta_graph(checkpoint + '.meta').

I'm not sure whether I'm handling the decoder's output correctly, so here is the corresponding code:

# 5. Predicting decoder
# Share params with Training Decoder
tiled_dec_start_state = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, beam_width)
tiled_src_seq_len = tf.contrib.seq2seq.tile_batch(src_seq_len, beam_width)

with tf.variable_scope('decode', reuse=True):
    batch_size_tensor = tf.constant(batch_size)
    beam_decoder_cell = get_decoder_cell(tiled_encoder_outputs, tiled_src_seq_len, 2 * num_units)
    beam_initial_state = beam_decoder_cell.zero_state(batch_size_tensor * beam_width, tf.float32)
    beam_initial_state = beam_initial_state.clone(cell_state=tiled_dec_start_state)

    start_tokens = tf.tile(tf.constant([c2id['<GO>']], dtype=tf.int32), [batch_size], name='start_tokens')

    predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=beam_decoder_cell,
        embedding=decoder_embeddings,
        start_tokens=start_tokens,
        end_token=c2id['<EOS>'],
        initial_state=beam_initial_state,
        beam_width=beam_width,
        output_layer=output_layer
    )
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=predicting_decoder, maximum_iterations=max_tgt_seq_len)
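tf.contrib.seq2seq.tile_batch repeats each batch entry beam_width times in a row (entries t[0], t[0], ..., t[1], t[1], ...), which is the layout BeamSearchDecoder expects for the encoder outputs, source lengths and initial state. tf.tile would instead repeat the whole batch as a block. A NumPy sketch of the difference, assuming toy shapes chosen only for illustration (np.repeat plays the role of tile_batch here; it is an analogy, not the TensorFlow implementation):

```python
import numpy as np

batch_size, beam_width = 2, 3

# Toy "encoder state": one row per batch entry; the values mark the entry index.
encoder_state = np.array([[0.0, 0.0], [1.0, 1.0]])  # shape [batch_size, 2]

# tile_batch-style layout: each entry repeated beam_width times consecutively.
tile_batch_like = np.repeat(encoder_state, beam_width, axis=0)

# tf.tile-style layout: the whole batch repeated beam_width times as a block.
tile_like = np.tile(encoder_state, (beam_width, 1))

print(tile_batch_like[:, 0])  # [0. 0. 0. 1. 1. 1.]
print(tile_like[:, 0])        # [0. 1. 0. 1. 0. 1.]
```

In the question's code, start_tokens keeps shape [batch_size] while the cell state is built for batch_size * beam_width, which is what BeamSearchDecoder expects: it expands the start tokens across beams itself.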

Handling the output:

training_decoder_output, predicting_decoder_output = seq2seq_model(params...)

training_logits = tf.identity(training_decoder_output.rnn_output, name='logits')
predicting_logits = tf.identity(predicting_decoder_output.predicted_ids[:,:,0], name='predictions')
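GatherTree, the op named in the error, is what BeamSearchDecoder.finalize uses to turn the per-step beam ids and parent-beam pointers into the complete sequences stored in predicted_ids. A simplified NumPy re-implementation of that backtracking, written here for illustration only (gather_tree_sketch is hypothetical, not TensorFlow's kernel, and it ignores sequence lengths and end-token handling), also shows what predicted_ids[:, :, 0] selects:

```python
import numpy as np


def gather_tree_sketch(step_ids, parent_ids):
    """Backtrack beam-search pointers (simplified; not TensorFlow's kernel).

    step_ids, parent_ids: int arrays of shape [max_time, batch_size, beam_width].
    parent_ids[t, b, k] is the beam at step t - 1 that beam k at step t extends.
    Returns the full sequences with the same time-major shape.
    """
    max_time, batch_size, beam_width = step_ids.shape
    beams = np.zeros_like(step_ids)
    for b in range(batch_size):
        for k in range(beam_width):
            parent = k
            # Walk from the last step back to the first, following parent pointers.
            for t in range(max_time - 1, -1, -1):
                beams[t, b, k] = step_ids[t, b, parent]
                parent = parent_ids[t, b, parent]
    return beams


# Toy decode: 3 steps, batch of 1, beam_width of 2.
step_ids = np.array([[[10, 20]], [[11, 21]], [[12, 22]]])
parent_ids = np.array([[[0, 1]], [[1, 0]], [[1, 1]]])
beams = gather_tree_sketch(step_ids, parent_ids)

# dynamic_decode returns predicted_ids batch-major by default: [batch, time, beam].
predicted_ids = np.transpose(beams, (1, 0, 2))
print(predicted_ids[:, :, 0])  # [[10 21 12]]
```

Slicing [:, :, 0] keeps beam index 0, the top-ranked beam in BeamSearchDecoder's ordering, for every batch entry and time step.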

Also, I found something in the nmt model at https://github.com/tensorflow/nmt/blob/77e6c55052ba31a8d733c94bb820d091c8156d35/nmt/model.py (line 391):

if beam_width > 0:
    logits = tf.no_op()
    sample_id = outputs.predicted_ids
else:
    logits = outputs.rnn_output
    sample_id = outputs.sample_id
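The nmt snippet branches because the two decoders return different output structures: BasicDecoder yields BasicDecoderOutput(rnn_output, sample_id), while BeamSearchDecoder run through dynamic_decode yields FinalBeamSearchDecoderOutput(predicted_ids, beam_search_decoder_output), which has no rnn_output field. A plain-Python sketch of that branch, using namedtuple stand-ins that copy only the field names (they are not the real TensorFlow classes, and select_outputs is a hypothetical helper):

```python
from collections import namedtuple

# Stand-ins mirroring the field names of the TF 1.x output tuples (not the real classes).
BasicDecoderOutput = namedtuple("BasicDecoderOutput", ["rnn_output", "sample_id"])
FinalBeamSearchDecoderOutput = namedtuple(
    "FinalBeamSearchDecoderOutput", ["predicted_ids", "beam_search_decoder_output"])


def select_outputs(outputs, beam_width):
    """Mirror the nmt branch: beam search yields ids but no logits."""
    if beam_width > 0:
        logits = None  # nmt uses tf.no_op() as a placeholder here
        sample_id = outputs.predicted_ids
    else:
        logits = outputs.rnn_output
        sample_id = outputs.sample_id
    return logits, sample_id


greedy = BasicDecoderOutput(rnn_output=[[0.1, 0.9]], sample_id=[1])
beam = FinalBeamSearchDecoderOutput(predicted_ids=[[[1, 2]]], beam_search_decoder_output=None)
print(select_outputs(greedy, beam_width=0))  # ([[0.1, 0.9]], [1])
print(select_outputs(beam, beam_width=2))    # (None, [[[1, 2]]])
```

The question's 'predictions' tensor already reads predicted_ids rather than rnn_output for the beam-search decoder, which is the same pattern as the nmt branch.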

Is this related to my error?

Thanks in advance!

0 answers