如何在张量流中切割具有动态形状的张量?

时间:2016-05-24 07:13:43

标签: slice tensorflow

假设有一个张量state包含一个rnn状态列表,另一个张量prob包含每个状态的概率。

state = tf.placeholder(tf.float32, [None, 49, 32])
print state.get_shape()  # (?, 49, 32) (batch_size, candidate_size, state_size)

prob = tf.placeholder(tf.float32, [None, 49])
print prob.get_shape() # (?, 49) (batch_size, candidate_size)

# Now I want to fetch 7 states with top probabilities
_, indices = tf.nn.top_k(prob, 7)
print indices.get_shape() # (?, 7)

如何使用stateindices进行切片?

编辑:

使用tf.gather(state, indices)的问题在于它只会沿第一维(即批量维度)对state进行切片。在这里,我们希望沿着第二维(长度为49)对其进行切片。

0 个答案:

没有答案