如何计算沿轴的掩盖张量的中值?

时间:2020-08-27 17:54:45

标签: pytorch

我有一个尺寸为X的浮点数的张量n x m和一个尺寸为Y的布尔值的张量n x m。我想沿着一个轴计算X的平均值,中位数和最大值,但仅考虑XY中正确的值。类似于X[Y].mean(dim=1)。这是不可能的,因为X[Y]始终是一维张量。

编辑:

平均而言,我能够做到:

masked_X = X * Y
masked_X_mean = masked_X.sum(dim=1) / Y.sum(dim=1)

最大:

masked_X = X
masked_X[Y] = float('-inf')
masked_X_max = masked_X.max(dim=1)

但是对于中位数,我无法发挥创造力。有什么建议吗?

例如

X = torch.tensor([[1, 1, 1],
                  [2, 2, 4]]).type(torch.float32)
Y = torch.tensor([[0, 1, 0],
                  [1, 0, 1]]).type(torch.bool)

预期产量

mean = [1., 3.]
median = [1., 2.]
var = [0., 1.]

2 个答案:

答案 0 :(得分:0)

最大和中位数:

由于张量之一是布尔值,因此对原件和掩码进行元素逐个相乘,然后像这样计算max / median会很棒。

array = torch.randint(10, (4,4))
mask = torch.randint(2, (4,4)) #it will just generate the [0,1] values]
sol_max = torch.max(array*mask)
sol_median = torch.median(array*mask)

答案 1 :(得分:0)

这是我迄今为止最好的:

outs = []
for x, y in zip(X, Y):  # X, Y could be permuted to loop over desired axis
    out = torch.median(torch.masked_select(x, y))
    outs.append(out)
torch.tensor(outs)

如果有人有更好的解决方案,我们将不胜感激。