从张量流中的索引(top_k)获取掩码

时间:2018-11-20 16:12:57

标签: python tensorflow attention-model

我有两个张量,a和b,a是top-k张量,b是遮罩张量。 a的形状为[batch_size,k],b的形状为[batch_size * seq_len],dtype为bool,全部由False初始化。 a的每一行都有k个整数,每个i的值i表示b的相应raw的第i个项应设置为True。

例如: b是[[False,False,False,False,False],[False,False,False,False,False]]。

a is [[0,4],[1,2]],将b中a的对应索引设置为True。 那么结果就是[[True,False,False,False,True],[False,True,True,False,False]]。

0 个答案:

没有答案