如何加快此Keras Attention计算?

时间:2018-03-08 14:30:47

标签: python tensorflow keras vectorization

我为AttentiveLSTMCellAttentiveLSTM(RNN)编写了一个自定义keras图层,与keras的 RNN方法一致。该注意机制由Bahdanau描述,其中,在编码器/解码器模型中,从编码器的所有输出和解码器的当前隐藏状态创建“上下文”向量。然后,我将每个时间步的上下文向量附加到输入。

该模型用于制作Dialog Agent,但与架构中的NMT模型(类似任务)非常相似。

然而,在添加这种注意机制时,我已经减慢了我的网络5倍的训练速度,我真的想知道如何编写代码的一部分,这样可以更有效地减慢它的速度办法。

计算的主要内容在这里完成:

h_tm1 = states[0]  # previous memory state
c_tm1 = states[1]  # previous carry state

# attention mechanism

# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)

# multiplty the weight matrix with the repeated (current) hidden state
_Wxstm = K.dot(_stm, self.kernel_w)

# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)

# append the context vector to the inputs
inputs = K.concatenate([inputs, context])
<{1}} call方法中的

(一次级)。

可以找到完整的代码here。如果有必要提供一些数据和方法来与模型进行交互,那么我就可以做到。

有什么想法吗?当然,如果这里有一些聪明的话,我会在GPU上进行培训。

2 个答案:

答案 0 :(得分:0)

我建议使用relu而不是tanh来训练你的模型,因为这个操作的计算速度要快得多。这将节省您的训练示例顺序的计算时间*每个示例的平均序列长度*时期数。

另外,我会评估附加上下文向量的性能改进,请记住这会减慢其他参数的迭代周期。如果它没有给你太多改进,那么可能值得尝试其他方法。

答案 1 :(得分:0)

您修改了LSTM类,该类非常适合CPU计算,但是您提到您正在使用GPU进行训练。

我建议研究cudnn-recurrent的实现 或进一步放入使用的tf part中。也许您可以在那里扩展代码。