torch.unique()中的参数“ dim”如何工作?

时间:2019-01-18 23:26:31

标签: pytorch

我试图提取矩阵每一行中的唯一值,并将它们返回到同一矩阵中(重复值设置为0),例如,我想进行变换

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])

即行中的顺序无关紧要。我尝试使用pytorch.unique(),在文档中提到可以使用参数dim指定采用唯一值的维。但是,它似乎不适用于这种情况。

我尝试过:

output= torch.unique(torch.Tensor([[4,2,52,2,2],[5,2,6,6,5]]), dim = 1)

output

哪个给

tensor([[ 2.,  2.,  2.,  4., 52.],
        [ 2.,  5.,  6.,  5.,  6.]])

有人对此有特定的解决方法吗?如果可能的话,我正在尝试避免for循环。

2 个答案:

答案 0 :(得分:1)

必须承认unique函数有时会在没有给出适当示例和解释的情况下非常混乱。

dim参数指定要在矩阵张量上应用的尺寸。

例如,在2D矩阵中,dim=0将使操作垂直执行,而dim=1则是水平执行。

例如,让我们考虑一个带有dim=1的4x4矩阵。从下面的代码中可以看到,unique操作是逐行应用的。

您会注意到第一行和最后一行两次出现数字11。 Numpy and Torch这样做是为了保留最终矩阵的形状。

但是,如果不指定任何尺寸,则割炬将自动展平矩阵,然后对其应用unique,您将获得包含唯一数据的一维数组。

import torch

m = torch.Tensor([
    [11, 11, 12,11], 
    [13, 11, 12,11], 
    [16, 11, 12, 11],  
    [11, 11, 12, 11]
])

output, indices = torch.unique(m, sorted=True, return_inverse=True, dim=1)
print("Ori \n{}".format(m.numpy()))
print("Sorted \n{}".format(output.numpy()))
print("Indices \n{}".format(indices.numpy()))

# without specifying dimension
output, indices = torch.unique(m, sorted=True, return_inverse=True)
print("Sorted (no dim) \n{}".format(output.numpy()))

结果(dim = 1)

Ori
[[11. 11. 12. 11.]
 [13. 11. 12. 11.]
 [16. 11. 12. 11.]
 [11. 11. 12. 11.]]
Sorted
[[11. 11. 12.]
 [11. 13. 12.]
 [11. 16. 12.]
 [11. 11. 12.]]
Indices
[1 0 2 0]

结果(无尺寸)

Sorted (no dim)
[11. 12. 13. 16.]

答案 1 :(得分:0)

我第一次使用torch.unique时感到困惑。经过一些实验,我终于弄清楚了dim参数是如何工作的。 torch.unique的文档说:

  

counts(Tensor):(可选)如果return_counts为True,将存在一个附加的返回张量(与output或output.size(dim)相同的形状,如果指定了dim),表示每个唯一值的出现次数或张量

例如,如果您的输入张量是大小为n x m x k且为dim=2的3D张量,则unique将比较大小为n x m的k个矩阵。换句话说,它将把除dim 2以外的所有维度都视为张量,并进行比较。