假设有一个张量state
包含一个rnn状态列表,另一个张量prob
包含每个状态的概率。
state = tf.placeholder(tf.float32, [None, 49, 32])
print state.get_shape() # (?, 49, 32) (batch_size, candidate_size, state_size)
prob = tf.placeholder(tf.float32, [None, 49])
print prob.get_shape() # (?, 49) (batch_size, candidate_size)
# Now I want to fetch 7 states with top probabilities
_, indices = tf.nn.top_k(prob, 7)
print indices.get_shape() # (?, 7)
如何使用state
对indices
进行切片?
使用tf.gather(state, indices)
的问题在于它只会沿第一维(即批量维度)对state
进行切片。在这里,我们希望沿着第二维(长度为49)对其进行切片。