我有以下设置:
B = Batchsize
N = Number of Objects
T = Number of Targets
L = Length of feature embedding per target
对于每个对象,我想关注一个目标。该模型通过将向量attention_weights
与arg shape=[B,N,T]
的argmax一起确定来决定要参加的目标:
pick = tf.math.argmax(attention_weights, axis=2)
因此pick
的形状为[B,N]
,每个条目都是一个索引。现在,我想使用这些索引来访问正确的目标特征
target_features.set_shape(target_features, [B, D, L])
features_picked = tf.some_function(target_features, pick)
我的问题是,tf.some_function
用什么?它与tf.gather有关吗?在这种情况下,我很难弄清楚如何使用它。
在此先感谢您的帮助!
PS:我正在使用tf。版本 ='1.13.1'
答案 0 :(得分:0)
我想出了以下解决方案,一旦我确认它正在执行应做的工作,我将接受答案:
我平铺了target_features
,因此它的形状为[B,N,T,L]
然后我这样做:
features_picked = tf.batch_gather(target_features, indices=pick)
其中features_picked
的形状为[B, N, 1, L]