我有矩阵(3 x 15)dummies
,其中标记序列作为行:
[[ 1 66 67 68 0 0 0 0 0 0 0 0 0 0 0]
[ 1 66 67 66 68 66 67 66 0 0 0 0 0 0 0]
[ 1 66 67 68 18 19 20 21 22 23 24 25 26 17 0]]
此外,还有一个张量probs
,形状为(3 x 15 x n_tokens),具有令牌概率。
从probs
中,我只需要选择dummies
中令牌的概率。
我认为,可以将矩阵用作张量的索引,但是我还没有找到如何做的。
答案 0 :(得分:1)
您可以这样做:
import tensorflow as tf
dummies = ...
probs = ...
s = tf.shape(dummies)
i = tf.range(s[0])
j = tf.range(s[1])
ii, jj = tf.meshgrid(i, j, indexing='ij')
idx = tf.stack([ii, jj, dummies], axis=-1)
result = tf.gather_nd(probs, idx)