pytorch中位数-是bug还是我使用错了

时间:2019-01-22 14:52:13

标签: python pytorch median torch

我试图获取2D torch.tensor每行的中值。但是与使用标准数组或numpy相比,结果不是我期望的

import torch
import numpy as np
from statistics import median

print(torch.__version__)
>>> 0.4.1

y = [[1, 2, 3, 5, 9, 1],[1, 2, 3, 5, 9, 1]]
median(y[0])
>>> 2.5

np.median(y,axis=1)
>>> array([2.5, 2.5])

yt = torch.tensor(y,dtype=torch.float32)
yt.median(1)[0]
>>> tensor([2., 2.])

2 个答案:

答案 0 :(得分:2)

看起来像这是本期提到的Torch的预期行为

https://github.com/pytorch/pytorch/issues/1837
https://github.com/torch/torch7/pull/182

上面链接中提到的推理

  

在有奇数个元素的情况下,Median返回“ middle”元素,否则返回“ before-middle”元素(也可以采用另一种约定来表示两个居中元素的均值,但这要多两倍)昂贵,所以我决定买这个。

答案 1 :(得分:0)

您可以使用pytorch模拟numpy中位数:

volume()