How can I use tf.contrib.seq2seq with eager execution?

Asked: 2018-11-06 08:49:54

Tags: python tensorflow machine-learning seq2seq

I am working on a seq2seq model that uses the tf.contrib.seq2seq module, and I would like to build it so that it is compatible with both graph mode and eager execution. As far as I know, tf.contrib.seq2seq does not yet work under eager execution (please correct me if I am wrong), so I wanted to wrap it in tfe.defun(). However, I get the following error:

RuntimeError: tf.device does not support functions when eager execution is enabled.

at the following line:

final_outputs, final_state, final_sequence_lengths = \
                tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
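
For reference, this is roughly how I set up eager execution and tfe (a sketch of my environment; I am on TensorFlow 1.x, where tf.contrib is still available):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()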

Here are the two places where I use tfe.defun():

def train_decoder_model(self, function_encoder_outputs, decoder_inputs, target_sequence_length):
    with tf.name_scope('embedding'):
        #  The embedding layer expects integer instead of one-hot encodings.
        decoder_inputs_ints = tf.argmax(decoder_inputs, axis=-1)
        #  Perform the embedding on the decoder input.
        decoder_embedding = tf.nn.embedding_lookup(self._emb_matrix, decoder_inputs_ints)
    with tf.name_scope('decoder'):
        target_sequence_length = tf.cast(target_sequence_length, tf.int32)
        def graph_decoder():
            helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding, target_sequence_length)
            decoder = tf.contrib.seq2seq.BasicDecoder(self.decoder_cell, helper,
                                                      function_encoder_outputs, self.decoder_dense)
            final_outputs, final_state, final_sequence_lengths = \
                tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
            return final_outputs
        graph_func = tfe.defun(graph_decoder)
        final_outputs = graph_func()
    return final_outputs

def loss(self, labels, logits, doc_length):
    def graph_loss():
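        # Mask padded positions so they do not contribute to the sequence loss.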
        masks = tf.sequence_mask(doc_length, tf.reduce_max(doc_length), dtype=tf.float32, name='masks')
        return tf.contrib.seq2seq.sequence_loss(logits, tf.argmax(labels, -1), masks)
    graph_func = tfe.defun(graph_loss)
    return graph_func()
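
To make this easier to reproduce, here is a minimal, self-contained sketch of the decoder part of the pattern I am attempting. The cell, the dense output layer, the zero initial state, and all shapes below are placeholders rather than my actual model; the call to the defun-wrapped function at the end is where I hit the RuntimeError quoted above:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

# Placeholder sizes, not my real model dimensions.
batch_size, max_time, emb_dim, vocab_size = 4, 7, 32, 50

cell = tf.nn.rnn_cell.LSTMCell(emb_dim)
output_layer = tf.layers.Dense(vocab_size)
decoder_embedding = tf.random_normal([batch_size, max_time, emb_dim])
target_sequence_length = tf.fill([batch_size], max_time)

def graph_decoder():
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding, target_sequence_length)
    decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper,
                                              cell.zero_state(batch_size, tf.float32),
                                              output_layer)
    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
    return final_outputs

# Calling the defun-wrapped decoder under eager execution raises the RuntimeError above.
final_outputs = tfe.defun(graph_decoder)()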

Any clues on how to solve this?

0 answers:

There are no answers yet.