在一个mini_batch中,我有一个隐藏的单词嵌入张量为h
,形状(h)为(?,300),一个注意词嵌入张量attention_words
,其中shape(attention_words)是(?,num,300)。现在我想在B中找到A中每个单词中最接近的单词。在这种情况下,我使用余弦距离来测量,如
normed_attention_words = tf.nn.l2_normalize(attention_words, dim=2)
normed_hidden = tf.nn.l2_normalize(h, dim=1)
normed_hidden = tf.expand_dims(normed_hidden, 1) # (?,1,300)
#num is the number of attention words
normed_hidden = tf.tile(normed_hidden, [1, num, 1]) #(?,num,300)
cosine_similarity = tf.reduce_sum(tf.mul(normed_attention_words, normed_hidden), 2)
closest_words = tf.argmax(cosine_similarity, 1)
因此,形状(nearest_words)是(?,)。我想利用张量closest_words
作为切片原始attention_words
的索引。我已尝试多次但失败了(特别是tf.gather_nd,因为没有实现gather_nd的Gradient)。那么在这种情况下,如何解决这个问题呢?