如何将tf.nn.top_k中的索引与tf.gather_nd一起使用?

时间:2019-01-15 09:37:56

标签: python tensorflow

我正在尝试使用从tf.nn.top_k返回的索引从第二张量中提取值。

我尝试使用numpy类型索引以及直接使用tf.gather_nd,但是我发现索引错误。

#  temp_attention_weights of shape [I, B, 1, J]
top_values, top_indices = tf.nn.top_k(temp_attention_weights, k=top_k)

# top_indices of shape [I, B, 1, top_k], base_encoder_transformed of shape [I, B, 1, J]

# I now want to extract from base_encoder_transformed top_indices
base_encoder_transformed = tf.gather_nd(base_encoder_transformed, indices=top_indices)  

# base_encoder_transformed should be of shape [I, B, 1, top_k]

我注意到top_indices的格式错误,但是我似乎无法将其转换为在tf.gather_nd中使用,其中最里面的维用于索引base_encoder_transformed中的每个对应元素。有人知道将top_indices转换为正确格式的方法吗?

2 个答案:

答案 0 :(得分:3)

top_indices仅在最后一个轴上索引,您也需要为其余轴添加索引。使用tf.meshgrid很容易:

import tensorflow as tf

# Example input data
I = 4
B = 3
J = 5
top_k = 2
x = tf.reshape(tf.range(I * B * J), (I, B, 1, J)) % 7
# Top K
top_values, top_indices = tf.nn.top_k(x, k=top_k)
# Make indices for the rest of axes
ii, jj, kk, _ = tf.meshgrid(
    tf.range(I),
    tf.range(B),
    tf.range(1),
    tf.range(top_k),
    indexing='ij')
# Stack complete index
index = tf.stack([ii, jj, kk, top_indices], axis=-1)
# Get the same values again
top_values_2 = tf.gather_nd(x, index)
# Test
with tf.Session() as sess:
    v1, v2 = sess.run([top_values, top_values_2])
    print((v1 == v2).all())
    # True

答案 1 :(得分:1)

我看不到使用tf.gather_nd的理由。通过将tf.meshgridtf.gather参数一起使用,有一个更简单,更快速的解决方案(无需使用batch_dims)。

import tensorflow as tf

# Example input data
I = 4
B = 3
J = 5
top_k = 2
x = tf.reshape(tf.range(I * B * J), (I, B, 1, J)) % 7
# Top K
top_values, top_indices = tf.nn.top_k(x, k=top_k)
#Gather indices along last axis
top_values_2 = tf.gather(x, top_indices, batch_dims = 3)

tf.reduce_all(top_values_2 == top_values).numpy()
#True

请注意,在这种情况下,batch_dims是3,因为我们想从最后一个轴开始收集,并且x的秩是4。