切片指数为张量的张量

时间:2016-09-30 12:15:43

标签: tensorflow

在一个mini_batch中,我有一个隐藏的单词嵌入张量为h,形状(h)为(?,300),一个注意词嵌入张量attention_words,其中shape(attention_words)是(?,num,300)。现在我想在B中找到A中每个单词中最接近的单词。在这种情况下,我使用余弦距离来测量,如

    normed_attention_words = tf.nn.l2_normalize(attention_words, dim=2)
    normed_hidden = tf.nn.l2_normalize(h, dim=1) 

    normed_hidden = tf.expand_dims(normed_hidden, 1) # (?,1,300)

    #num is the number of attention words
    normed_hidden = tf.tile(normed_hidden, [1, num, 1]) #(?,num,300)

    cosine_similarity = tf.reduce_sum(tf.mul(normed_attention_words, normed_hidden), 2)

    closest_words = tf.argmax(cosine_similarity, 1)

因此,形状(nearest_words)是(?,)。我想利用张量closest_words作为切片原始attention_words的索引。我已尝试多次但失败了(特别是tf.gather_nd,因为没有实现gather_nd的Gradient)。那么在这种情况下,如何解决这个问题呢?

0 个答案:

没有答案