我有一个T
形状的张量Batch_Size x Num_Items x Item_Dimension
和另一个P
形状的张量Batch_Size x Num_Items
,其中每批P中的Num_Items值总和为1(概率为每个批次的项目分配)。我想不根据概率分布P从T替换N
个项目。所得张量应为Batch_Size x N x Item_Dimension
。我该怎么办?
答案 0 :(得分:1)
看看 https://github.com/tensorflow/tensorflow/issues/9260
尽管请注意,我认为您需要Logit而不是概率来进行Gumbel最大采样。