Seq2Seq:如何使用指针网络和固定词汇在推理时进行解码

时间:2019-05-04 17:53:17

标签: tensorflow keras seq2seq

我正在构建一个文本摘要生成器Seq2Seq模型,该模型使用指针网络以概率p从输入中进行采样,否则它将使用固定的词汇表解码下一个目标单词。如何在推理时编写代码以解码一批示例,而不会为每个示例循环?

    def predict_batch(self, X):
        assert self.embeddings, "Call self.set_embeddings_layer first"
        X = self.embeddings(X)
        enc_states, h1, h2 = self.encoder(X)
        input_tokens = tf.convert_to_tensor([self.start_token] * X.shape[0])

        # put last encoder state as attention vec at start
        c_vec = h1
        outputs = []

        for _ in range(self.max_len):
            dec_input = self.embeddings(input_tokens)
            decoded_state, h1, h2 = self.decoder(dec_input, c_vec, [h1, h2])
            c_vec, _, pointer_prob  = self.attention(enc_states, 
                                                     decoded_state)

            # Compute switch probability to decide if to extract the next
            # word token with a pointer network or a fixed vocabulary
            switch_probs = self.pointer_switch(h1, c_vec)

            ...

计算切换概率后,需要根据这些概率执行不同的代码。

例如,如果switch_probs为[0.2,0.8,0.5],并且我生成了一些随机数,例如[0.4,0.7,0.6],则我需要为示例[0,2]执行函数A,并为示例B执行函数B例如[1]。

有没有一种方法可以避免每个示例的循环,而使用一些高效的Tensorflow API来做到这一点?

以下是我要完成的工作的一个示例: https://arxiv.org/pdf/1704.04368.pdf

0 个答案:

没有答案