我有一个 torch.Tensor
形状的 (2, 2, 2)
(可以更大),其中的值在 [0, 1]
范围内标准化。
现在我得到一个正整数 K
,它告诉我我需要创建一个掩码,其中对于批次中的每个 2D 张量,如果它大于所有的 1/k
,则值为 1值,其他地方为 0。返回掩码也具有形状 (2, 2, 2)
。
例如,如果我有一个这样的批次:
tensor([[[1., 3.],
[2., 4.]],
[[5., 7.],
[9., 8.]]])
并让 K=2
,这意味着我必须屏蔽大于每个 2D 张量内所有值的 50% 的值。
在示例中,0.5 分位数是 2.5
和 7.5
,因此这是所需的输出:
tensor([[[0, 1],
[0, 1]],
[[0, 0],
[1, 1]]])
我试过了:
a = torch.tensor([[[0, 1],
[0, 1]],
[[0, 0],
[1, 1]]])
quantile = torch.tensor([torch.quantile(x, 1/K) for x in a])
torch.where(a > val, 1, 0)
但这就是结果:
tensor([[[0, 0],
[0, 0]],
[[1, 0],
[1, 1]]])
答案 0 :(得分:1)
{
"config": [{
"key1": "value1",
"key2": "value2",
"key3": [{
"key3.1": "value3.1",
"key3.2": "value3.2",
"key3.3": [{
"key3.3.1": "value3.3.1",
"key3.3.2": "value3.3.2"
}]
}]
}]}
在这个资源之后是:
t = torch.tensor([[[1., 3.],
[2., 4.]],
[[5., 7.],
[9., 8.]]])
t_flat = torch.reshape(t, (t.shape[0], -1))
quants = torch.quantile(t_flat, 1/K, dim=1)
quants = torch..reshape(quants, (quants.shape[0], 1, 1))
res = torch.where(t > val, 1, 0)
这是你想要的