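I have an input tensor of IDs, `input_ids`, of shape `[B x T]`, and a corresponding embedding matrix of shape `[B x T x D]` (B: batch size, T: sequence length, D: dimension). The input IDs are vocabulary IDs, and the embedding matrix contains the corresponding embeddings.
From the embedding matrix, I want to select the elements that have a specific ID (e.g., `103`). This is easy to do with `tf.where` and `tf.gather_nd`, but what I don't know how to do is organize the result into a batch of shape `[B x N x D]`, where `N` is the maximum number of tokens with that ID (`103`) in a sequence. I want to pad with zero tensors as needed.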
The code probably shows it better (let's say `B=2, T=8, and D=3`):
import tensorflow as tf
tf.enable_eager_execution()
input_ids = tf.constant([[ 101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
                         [ 101, 103, 3793, 103, 2443, 2000, 103, 2469]])
embeddings = tf.random_normal((2,8,3))
# input_ids has two sequences: the first contains one 103 token, the second contains three.
I want to select from `embeddings` the entries that correspond to `input_ids==103`, and pad the rest of the result with zeros.
I can get this far:
indices = tf.where(tf.equal(input_ids, 103))
result = tf.gather_nd(indices=indices, params=embeddings)
# result.shape == [4 x 3]
# This results in a [4 x 3] matrix, where 4 = total number of 103 tokens in the batch
# and 3 is the embedding dimension.
# Now I want to organize this into a batch with the same batch size as the input,
# i.e., desired shape = [2 x 3 x 3]:
# the first [3 x 3] slice holds the single 103 embedding of the first sequence,
# followed by two rows of zero padding;
# the second [3 x 3] slice holds the three 103 embeddings of the second sequence.
In general, this yields an `[M x D]` tensor (M = total number of 103 tokens in the batch). What I want is `[B x N x D]`, where N = the maximum number of 103 tokens in any one sequence (3 for the case above). I hope the description is clear (it is a bit hard to explain the exact problem).
How can I achieve this?
Answer 0 (score: 1)
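I think we can exploit the fact that `tf.gather_nd` returns `0` when `indices` is negative. Note that the TensorFlow documentation describes this zero-fill for out-of-bounds indices on GPU; on CPU, an out-of-bounds index raises an error instead.
First, get the index values in `embeddings` of the tokens with the target ID.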
import tensorflow as tf
tf.enable_eager_execution()
input_ids = tf.constant([[ 101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
[ 101, 103, 3793, 103, 2443, 2000, 103, 2469]])
embeddings = tf.random_normal((2,8,3))
condition = tf.equal(input_ids, 103)
indices_value= tf.where(condition)
# [[0 3]
# [1 1]
# [1 3]
# [1 6]]
Then we need the number of matching tokens in each sequence, and a mask over the index values.
length = tf.reduce_sum(tf.cast(condition,tf.int32),axis=-1)
# [1 3]
indices_mask = tf.sequence_mask(length,tf.reduce_max(length))
# [[ True False False]
# [ True True True]]
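The mask matters because its `True` positions, read in row-major order, line up one-to-one with the rows of `indices_value`. A minimal sketch of that correspondence, assuming TensorFlow 2.x (eager execution on by default) and with the lengths hard-coded from the example above:

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager by default

# Per-sequence match counts from the example above.
length = tf.constant([1, 3])
indices_mask = tf.sequence_mask(length, tf.reduce_max(length))

# (batch, slot) coordinates of every True entry, in row-major order.
slots = tf.where(indices_mask)
print(slots.numpy().tolist())  # [[0, 0], [1, 0], [1, 1], [1, 2]]
```

The k-th row of `slots` is the destination of the k-th row of `indices_value`, since both are enumerated sequence by sequence.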
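Next, we need to place each index value at its slot within its sequence. `tf.scatter_nd` leaves unfilled slots at 0, so we add 1 before scattering and subtract 1 afterwards, which turns the unfilled slots into -1.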
result_indices = tf.scatter_nd(tf.where(indices_mask),
indices_value+1,
(indices_mask.shape[0],indices_mask.shape[1],tf.rank(input_ids)))-1
# [[[ 0 3]
# [-1 -1]
# [-1 -1]]
#
# [[ 1 1]
# [ 1 3]
# [ 1 6]]]
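To see why the shift is needed, here is a small sketch comparing the scatter with and without it. It assumes TensorFlow 2.x; the `slots` constant is just `tf.where(indices_mask)` written out by hand for this example.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager by default

indices_value = tf.constant([[0, 3], [1, 1], [1, 3], [1, 6]], dtype=tf.int64)
# Destination (batch, slot) pairs, i.e. tf.where(indices_mask) for this example.
slots = tf.constant([[0, 0], [1, 0], [1, 1], [1, 2]], dtype=tf.int64)
shape = tf.constant([2, 3, 2], dtype=tf.int64)

# Without the shift, empty slots look like the valid index [0, 0].
unshifted = tf.scatter_nd(slots, indices_value, shape)
# With the shift, empty slots end up as [-1, -1].
shifted = tf.scatter_nd(slots, indices_value + 1, shape) - 1

print(unshifted[0].numpy().tolist())  # [[0, 3], [0, 0], [0, 0]]
print(shifted[0].numpy().tolist())    # [[0, 3], [-1, -1], [-1, -1]]
```

Without the shift, the padding slots would silently gather the embedding at position `[0, 0]` instead of zeros.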
Finally, we get the result with `tf.gather_nd`.
result = tf.gather_nd(indices=result_indices, params=embeddings)
print(result)
# [[[ 1.22885 0.77642244 -0.82193506]
# [ 0. 0. 0. ]
# [ 0. 0. 0. ]]
#
# [[-0.0567691 0.07378497 -0.4799046 ]
# [-1.1627238 -1.994217 0.8443906 ]
# [ 0.776338 -0.25828102 -1.7915782 ]]]
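If relying on `tf.gather_nd` with negative indices is not an option (for example on CPU, where out-of-bounds indices raise an error), a possible variant is to scatter the gathered embeddings directly into a zero tensor. This is only a sketch, assuming TensorFlow 2.x; `gather_padded` is a hypothetical helper name, not part of TensorFlow.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager by default


def gather_padded(input_ids, embeddings, token_id):
    """Hypothetical helper: collect embeddings of `token_id` into a zero-padded [B, N, D] tensor."""
    condition = tf.equal(input_ids, token_id)            # [B, T] bool
    # Embeddings of all matching tokens, in row-major order: [M, D]
    matched = tf.boolean_mask(embeddings, condition)
    # Number of matching tokens per sequence: [B]
    length = tf.reduce_sum(tf.cast(condition, tf.int32), axis=-1)
    max_len = tf.reduce_max(length)
    # Destination (batch, slot) pairs, also in row-major order: [M, 2]
    slots = tf.where(tf.sequence_mask(length, max_len))
    shape = tf.stack([tf.shape(input_ids)[0], max_len, tf.shape(embeddings)[-1]])
    # Slots that receive no update stay zero, which provides the padding.
    return tf.scatter_nd(slots, matched, tf.cast(shape, tf.int64))


input_ids = tf.constant([[101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
                         [101, 103, 3793, 103, 2443, 2000, 103, 2469]])
embeddings = tf.random.normal((2, 8, 3))
result = gather_padded(input_ids, embeddings, 103)
print(result.shape)  # (2, 3, 3)
```

Here the scatter moves the embeddings themselves rather than their indices, so no negative index is ever passed to a gather.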