在 pytorch 中对概率分布的张量进行采样

时间:2021-05-18 19:23:39

标签: pytorch

我想对形状为 (N, C, H, W) 的概率分布张量进行采样,其中维度 1(大小为 C)包含具有“C”可能性的归一化概率分布。是否有 pytorch 函数可以有效地并行采样张量中的所有分布?我只需要对每个分布采样一次,因此结果可以是形状相同的单热张量,也可以是形状为 (N, 1, H, W) 的指数张量。

1 个答案:

答案 0 :(得分:0)

我没有看到要采样的单个函数,但我能够通过计算累积概率、独立采样每个点,然后选择在分布维度中采样为 1 的第一个点来分几个步骤对张量进行采样:

reverse_cumulative = torch.flip(torch.cumsum(torch.flip(probabilities, [1]), dim=1), [1])
cumulative = probabilities / reverse_cumulative
sampled = (torch.rand(cumulative.shape, device=device()) <= cumulative)
idxs = sampled * one_hot
idxs[~sampled] = self.tile_count
sampled_idxs = idxs.min(dim=1).indices
相关问题