如何根据2D索引收集logit?

时间:2018-08-05 22:33:57

标签: tensorflow keras

假设序列[[2, 1, 4], [3, 4, 2]]由预先训练的LSTM生成。其维数为(2*3),表示每个样本中的batch-size = 23 time steps

例如,总共有5 features,因此登录数可能是:

[[[0.2, 0.2, 0.1, 0.1, 0.2], 
  [0.3, 0.2, 0.1, 0.1, 0.1],
  [0.1, 0.2, 0.1, 0.1, 0.3]],
 [[0.2, 0.2, 0.1, 0.1, 0.2],
  [0.2, 0.2, 0.1, 0.1, 0.2],
  [0.2, 0.2, 0.1, 0.1, 0.2]]]

我想使用序列作为索引,以从每个样本和每个时间步的对数中获取相应的概率。对于上面的示例,我想要得到的最终结果是

[[0.1, 0.2, 0.3],[0.1, 0.2, 0.1]]

我知道我可能需要tf.stack(),但是我对如何处理尺寸感到困惑。感谢您的帮助!

1 个答案:

答案 0 :(得分:0)

我想我找到了办法。

tf.losses.sparse_softmax_cross_entropy(labels = None, logits = None) 

labels是索引,logits是模型的输出。