TensorFlow: select embeddings from the input with a specific token ID and batch the results

Date: 2019-06-26 23:35:55

Tags: python tensorflow neural-network

I have an input tensor of token IDs, input_ids, of shape [B x T], and a corresponding embeddings matrix of shape [B x T x D] (B: batch size, T: sequence length, D: dimension). The input IDs are vocabulary IDs, and the embeddings matrix holds their corresponding embeddings.

I want to select from the embeddings matrix those elements whose token has a specific ID (e.g. 103). This is easy to do with tf.where and tf.gather_nd, but what I don't know is how to organize the result into a batch of shape [B x N x D], where N is the maximum number of tokens with that ID in any one sequence. I'd like to use zero vectors as padding where needed.

Code might show it better (let's say the target ID is 103, with B=2, T=8, and D=3):

import tensorflow as tf
tf.enable_eager_execution()

input_ids = tf.constant([[ 101, 1996, 16360,  103, 1010, 1996, 4223, 1997],
                         [ 101,  103,  3793,  103, 2443, 2000,  103, 2469]])
embeddings = tf.random_normal((2, 8, 3))
# input ids have two sequences. first one has one 103 element, while second has 3.

I want to select from embeddings the entries corresponding to input_ids == 103 and pad the rest of the result with zeros. I can do this:

indices = tf.where(tf.equal(input_ids, 103))
result = tf.gather_nd(indices=indices, params=embeddings)
# result.shape == [4, 3]
# This is a [4 x 3] matrix where 4 = total number of 103 tokens in the batch
# and 3 is their embedding dimension.

In general, this yields an [M x D] tensor (M = total number of 103 tokens in the batch). What I want instead is a [B x N x D] tensor (N = maximum number of 103 tokens in any sequence, 3 in the case above), where the first sequence's row holds its single 103 embedding followed by zero padding, and the second sequence's row holds all three of its 103 embeddings. I hope the description is clear (the exact problem is a bit hard to explain).

How can I achieve this?
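To make the desired output concrete, here is a loop-based reference in plain NumPy (my own sketch of the target behaviour, not the vectorized TensorFlow solution I'm looking for):

```python
import numpy as np

input_ids = np.array([[101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
                      [101,  103,  3793, 103, 2443, 2000,  103, 2469]])
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((2, 8, 3))   # stand-in for the real embeddings

condition = input_ids == 103
lengths = condition.sum(axis=1)               # 103-tokens per sequence: [1, 3]
B, N, D = input_ids.shape[0], lengths.max(), embeddings.shape[-1]

result = np.zeros((B, N, D))                  # desired [B x N x D] output
for b in range(B):
    hits = embeddings[b][condition[b]]        # embeddings of the 103 tokens
    result[b, :len(hits)] = hits              # left-aligned, zero-padded
print(result.shape)                           # (2, 3, 3)
```

The first row of result ends with two zero vectors of padding; the second row is fully occupied.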

1 Answer:

Answer 0 (score: 1)

I think we can exploit the property that tf.gather_nd writes 0 to the output wherever indices is out of range (e.g. negative). (Note: TensorFlow documents this zero-fill behaviour for out-of-bound indices on GPU; on CPU some versions raise an error instead, so verify it on your setup.)

First, get the positions of the target ID within input_ids.

import tensorflow as tf
tf.enable_eager_execution()

input_ids = tf.constant([[  101,  1996, 16360,   103,  1010,  1996,  4223,  1997],
                         [  101,   103,  3793,   103,  2443,  2000,   103,  2469]])
embeddings = tf.random_normal((2,8,3))

condition = tf.equal(input_ids, 103)
indices_value=  tf.where(condition)
# [[0 3]
#  [1 1]
#  [1 3]
#  [1 6]]

Then we get the number of target tokens in each sequence, and a mask for the index values.

length = tf.reduce_sum(tf.cast(condition,tf.int32),axis=-1)
# [1 3]
indices_mask = tf.sequence_mask(length,tf.reduce_max(length))
# [[ True False False]
#  [ True  True  True]]
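As an aside, tf.sequence_mask just builds a boolean [B x N] matrix from per-row lengths; the same mask can be sketched with a NumPy broadcasting comparison (my own illustration, not part of the answer's code):

```python
import numpy as np

# Per-sequence counts of the target token, as computed above.
length = np.array([1, 3])

# Equivalent of tf.sequence_mask(length, length.max()):
# column j is True while j < length[row].
indices_mask = np.arange(length.max()) < length[:, None]
print(indices_mask)
# [[ True False False]
#  [ True  True  True]]
```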

Next, we need to place each index pair at its position within its sequence. Scattering indices_value + 1 and then subtracting 1 leaves -1 in the unfilled (padding) slots.

result_indices = tf.scatter_nd(tf.where(indices_mask),
                               indices_value+1,
                               (indices_mask.shape[0],indices_mask.shape[1],tf.rank(input_ids)))-1
# [[[ 0  3]
#   [-1 -1]
#   [-1 -1]]
#
#  [[ 1  1]
#   [ 1  3]
#   [ 1  6]]]

Finally, we get the result with tf.gather_nd; the [-1, -1] padding indices are out of range and therefore gather zero vectors.

result = tf.gather_nd(indices=result_indices, params=embeddings)
print(result)
# [[[ 1.22885     0.77642244 -0.82193506]
#   [ 0.          0.          0.        ]
#   [ 0.          0.          0.        ]]
# 
#  [[-0.0567691   0.07378497 -0.4799046 ]
#   [-1.1627238  -1.994217    0.8443906 ]
#   [ 0.776338   -0.25828102 -1.7915782 ]]]
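To double-check the mechanics of this construction outside of TensorFlow, here is a NumPy emulation (all names are my own; the zero-fill for the -1 slots is done explicitly, standing in for the out-of-range behaviour of tf.gather_nd that the answer relies on):

```python
import numpy as np

input_ids = np.array([[101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
                      [101,  103,  3793, 103, 2443, 2000,  103, 2469]])
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((2, 8, 3))

condition = input_ids == 103
indices_value = np.argwhere(condition)         # like tf.where: [[0,3],[1,1],[1,3],[1,6]]
length = condition.sum(axis=1)                 # [1, 3]
N = length.max()
indices_mask = np.arange(N) < length[:, None]  # like tf.sequence_mask

# Scatter the (row, col) pairs into a [B x N x 2] grid; unfilled slots
# stay at -1, mirroring the answer's "+1 then -1" construction.
result_indices = np.full((2, N, 2), -1)
result_indices[indices_mask] = indices_value

# Emulate the gather: out-of-range (-1) indices yield zero vectors.
result = np.zeros((2, N, 3))
valid = result_indices[..., 0] >= 0
rows, cols = result_indices[valid, 0], result_indices[valid, 1]
result[valid] = embeddings[rows, cols]
```

The output has the same layout as the answer's: row 0 carries one embedding plus zero padding, row 1 carries all three.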