来自单热编码的Tensorflow掩码

时间:2017-01-07 01:38:50

标签: python tensorflow

我的标签是examples = tf.placeholder(tf.int32, [batch_size])形式的OHE,其中每个示例都是int范围内的0:ohe_size

我的输出采用softmax概率分布的形式,形状为[batch_size, ohe_size]

我试图找出如何创建一个掩码,让我只给出每个例子的概率分布。 e.g。

probs = [[0.1, 0.6, 0.3]
         [0.2, 0.1, 0.7]
         [0.9, 0.1, 0.0]]
examples = [2, 2, 0]

some_mask_func(probs, example) # <- Need this function    
> [0.3, 0.7, 0.9]

1 个答案:

答案 0 :(得分:2)

如果我理解你的例子,你需要tf.gather_nd

range = tf.range(tf.shape(examples)[0])
indices = tf.pack([range, examples], axis=1)
result = tf.gather_nd(probs, indices)