Pytorch:创建一个大于批处理中每个 2D 张量的第 n 个分位数的掩码

时间:2021-04-26 14:05:06

标签: python pytorch

我有一个 torch.Tensor 形状的 (2, 2, 2)(可以更大),其中的值在 [0, 1] 范围内标准化。

现在我得到一个正整数 K,它告诉我我需要创建一个掩码,其中对于批次中的每个 2D 张量,如果它大于所有的 1/k,则值为 1值,其他地方为 0。返回掩码也具有形状 (2, 2, 2)

例如,如果我有一个这样的批次:

tensor([[[1., 3.],
         [2., 4.]],
        [[5., 7.],
         [9., 8.]]])

并让 K=2,这意味着我必须屏蔽大于每个 2D 张量内所有值的 50% 的值。

在示例中,0.5 分位数是 2.57.5,因此这是所需的输出:

tensor([[[0, 1],
         [0, 1]],
        [[0, 0],
         [1, 1]]])

我试过了:

a = torch.tensor([[[0, 1],
                   [0, 1]],
                  [[0, 0],
                   [1, 1]]])
quantile = torch.tensor([torch.quantile(x, 1/K) for x in a])
torch.where(a > val, 1, 0)

但这就是结果:

tensor([[[0, 0],
         [0, 0]],
        [[1, 0],
         [1, 1]]])

1 个答案:

答案 0 :(得分:1)

{
"config": [{
    "key1": "value1",
    "key2": "value2",
    "key3": [{
        "key3.1": "value3.1",
        "key3.2": "value3.2",
        "key3.3": [{
            "key3.3.1": "value3.3.1",
            "key3.3.2": "value3.3.2"
        }]
    }]
}]}

在这个资源之后是:

t = torch.tensor([[[1., 3.],
         [2., 4.]],
        [[5., 7.],
         [9., 8.]]])

t_flat = torch.reshape(t, (t.shape[0], -1))
quants = torch.quantile(t_flat, 1/K, dim=1)
quants = torch..reshape(quants, (quants.shape[0], 1, 1))
res = torch.where(t > val, 1, 0)

这是你想要的