我正在构建一个文本摘要生成器Seq2Seq模型,该模型使用指针网络以概率p从输入中进行采样,否则它将使用固定的词汇表解码下一个目标单词。如何在推理时编写代码以解码一批示例,而不会为每个示例循环?
def predict_batch(self, X):
assert self.embeddings, "Call self.set_embeddings_layer first"
X = self.embeddings(X)
enc_states, h1, h2 = self.encoder(X)
input_tokens = tf.convert_to_tensor([self.start_token] * X.shape[0])
# put last encoder state as attention vec at start
c_vec = h1
outputs = []
for _ in range(self.max_len):
dec_input = self.embeddings(input_tokens)
decoded_state, h1, h2 = self.decoder(dec_input, c_vec, [h1, h2])
c_vec, _, pointer_prob = self.attention(enc_states,
decoded_state)
# Compute switch probability to decide if to extract the next
# word token with a pointer network or a fixed vocabulary
switch_probs = self.pointer_switch(h1, c_vec)
...
计算切换概率后,需要根据这些概率执行不同的代码。
例如,如果switch_probs为[0.2,0.8,0.5],并且我生成了一些随机数,例如[0.4,0.7,0.6],则我需要为示例[0,2]执行函数A,并为示例B执行函数B例如[1]。
有没有一种方法可以避免每个示例的循环,而使用一些高效的Tensorflow API来做到这一点?
以下是我要完成的工作的一个示例: https://arxiv.org/pdf/1704.04368.pdf