Implementing Attention in TensorFlow

Date: 2017-02-28 11:04:49

Tags: python tensorflow

I would like to check whether my attention implementation in TensorFlow is correct.

Basically, I am using the attention described in https://arxiv.org/pdf/1509.06664v1.pdf (just the baseline attention, not word-by-word attention). So far, I have implemented it without using the last hidden state h_N.
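For reference, these are the equations I am trying to implement, as I understand them from the paper (written per example, with Y the [dim x seq_len] matrix of hidden states and the h_N term dropped):

M      = \tanh(W^{Y} Y)                   % M      : [dim x seq_len]
\alpha = \operatorname{softmax}(w^{T} M)  % \alpha : [1 x seq_len]
r      = Y \alpha^{T}                     % r      : [dim x 1]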

import tensorflow as tf

def attention(hidden_states, name='baseline', reuse=None):
    '''
    hidden_states (inputs) are seq_len x batch_size x dim.
    Returns r, the representation of the hidden states weighted by the attention vector a.

    Note: I do not use the h_N vector and I also skip the last projection layer.
    '''
    shape = hidden_states.get_shape().as_list()  # [seq_len, batch_size, dim]
    dim = shape[2]

    with tf.variable_scope('attention_{}'.format(name), reuse=reuse):
        initializer = tf.random_uniform_initializer()

        # Initialize parameters
        weights_Y = tf.get_variable("weights_Y", [dim, dim], initializer=initializer)
        weights_w = tf.get_variable("weights_w", [dim, 1], initializer=initializer)

        # Equation => M = tanh(W^{Y}Y)
        tmp = tf.reshape(hidden_states, [-1, dim])
        Y = tf.matmul(tmp, weights_Y)
        Y = tf.reshape(Y, [shape[0], -1, dim])
        Y = tf.tanh(Y, name='M_matrix')

        # Equation => a = softmax(w^T M)
        Y = tf.reshape(Y, [-1, dim])
        a = tf.matmul(Y, weights_w)
        a = tf.reshape(a, [-1, shape[0]])
        a = tf.nn.softmax(a, name='attention_vector')

        # Equation => r = Ya^T
        # This is the part where I weight all hidden states by the attention vector
        a = tf.expand_dims(a, 2)
        H = tf.transpose(hidden_states, [1, 2, 0])  # [batch_size, dim, seq_len]
        r = tf.matmul(H, a, name='R_vector')
        r = tf.reshape(r, [-1, dim])

        # I skip the last projection layer since I do not use h_N
        return r
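For reference, this is roughly how I call it (the placeholder and sizes here are made up for illustration):

# hypothetical usage with made-up sizes
seq_len, batch_size, dim = 5, 3, 2
states = tf.placeholder(tf.float32, [seq_len, batch_size, dim])
r = attention(states, name='baseline')  # r has shape [batch_size, dim]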

This graph compiles, runs and trains correctly (the loss is decreasing, etc.), but the performance is lower than I expected. I would be grateful if someone could check whether I am doing this right.

In general:

1) For multiplying a [?, seq_len, dim] tensor by a [dim, dim] matrix: is it correct to use tf.reshape to go from [?, seq_len, dim] to [-1, dim], apply a matmul between the [-1, dim] tensor and the [dim, dim] matrix, and then reshape back to [?, seq_len, dim] after the matmul? (A sketch of what I mean follows below.)
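To make question 1 concrete, this is the pattern I mean (a minimal, self-contained sketch with made-up sizes; states and W are just illustrative names):

import tensorflow as tf
import numpy as np

batch_size, seq_len, dim = 3, 5, 2
states = tf.constant(np.random.randn(batch_size, seq_len, dim), tf.float32)  # [?, seq_len, dim]
W = tf.constant(np.random.randn(dim, dim), tf.float32)                       # [dim, dim]

flat = tf.reshape(states, [-1, dim])         # [batch_size*seq_len, dim]
proj = tf.matmul(flat, W)                    # [batch_size*seq_len, dim]
proj = tf.reshape(proj, [-1, seq_len, dim])  # back to [?, seq_len, dim]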

2) I noticed that I obtain an attention vector of shape (?, seq_len), so I need to compute (?, seq_len) x (?, dim, seq_len).

Is it correct to transpose and expand_dims from (?, seq_len) to (?, seq_len, 1) and then do a matmul (which, I believe, is what batch_matmul did in previous versions)? (Again, a sketch follows below.)
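And this is the expand_dims plus matmul step I am asking about in question 2 (again a minimal sketch with made-up sizes; in TF 1.0, tf.matmul batches over the leading dimension, which is what tf.batch_matmul used to do):

import tensorflow as tf
import numpy as np

batch_size, seq_len, dim = 3, 5, 2
a = tf.constant(np.random.randn(batch_size, seq_len), tf.float32)       # attention weights, (?, seq_len)
H = tf.constant(np.random.randn(batch_size, dim, seq_len), tf.float32)  # transposed hidden states, (?, dim, seq_len)

a3 = tf.expand_dims(a, 2)  # (?, seq_len, 1)
r = tf.matmul(H, a3)       # (?, dim, 1), a batched matrix-vector product
r = tf.squeeze(r, [2])     # (?, dim)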

Thanks in advance!

1 Answer:

Answer 0 (score: 1):

Not sure whether tf.einsum is implemented efficiently in TF 1.0, but it makes the computation very elegant.

import tensorflow as tf
import numpy as np

batch_size = 3
seq_len = 5
dim = 2
# [batch_size x seq_len x dim]  -- hidden states
Y = tf.constant(np.random.randn(batch_size, seq_len, dim), tf.float32)
# [batch_size x dim]            -- h_N
h = tf.constant(np.random.randn(batch_size, dim), tf.float32)

initializer = tf.random_uniform_initializer()
W = tf.get_variable("weights_Y", [dim, dim], initializer=initializer)
w = tf.get_variable("weights_w", [dim], initializer=initializer)

# [batch_size x seq_len x dim]  -- tanh(W^{Y}Y)
M = tf.tanh(tf.einsum("aij,jk->aik", Y, W))
# [batch_size x seq_len]        -- softmax(w^T M)
a = tf.nn.softmax(tf.einsum("aij,j->ai", M, w))
# [batch_size x dim]            -- Ya^T
r = tf.einsum("aij,ai->aj", Y, a)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a_val, r_val = sess.run([a, r])
    print("a:", a_val, "\nr:", r_val)