Tensorflow:如何为每个样本批量索引一个元素

时间:2019-10-21 11:10:05

标签: tensorflow indexing

我有以下设置:

B = Batchsize
N = Number of Objects
T = Number of Targets
L = Length of feature embedding per target

对于每个对象,我想关注一个目标。该模型通过将向量attention_weights与arg shape=[B,N,T]的argmax一起确定来决定要参加的目标:

pick = tf.math.argmax(attention_weights, axis=2)

因此pick的形状为[B,N],每个条目都是一个索引。现在,我想使用这些索引来访问正确的目标特征

target_features.set_shape(target_features, [B, D, L])
features_picked = tf.some_function(target_features, pick)

我的问题是,tf.some_function用什么?它与tf.gather有关吗?在这种情况下,我很难弄清楚如何使用它。

在此先感谢您的帮助!

PS:我正在使用tf。版本 ='1.13.1'

1 个答案:

答案 0 :(得分:0)

我想出了以下解决方案,一旦我确认它正在执行应做的工作,我将接受答案:

我平铺了target_features,因此它的形状为[B,N,T,L]

然后我这样做:


features_picked = tf.batch_gather(target_features, indices=pick)

其中features_picked的形状为[B, N, 1, L]