火炬张量的逻辑索引

时间:2016-03-31 20:32:09

标签: indexing torch

我正在寻找一种优雅的方法来选择满足某些约束条件的火炬张量子集。 例如,说我有:

A = torch.rand(10,2)-1

S是10x1张量,

sel = torch.ge(S,5) -- this is a ByteTensor

我希望能够进行逻辑索引,如下所示:

A1 = A[sel]

但这并不奏效。 因此index函数接受LongTensor,但我找不到将S转换为LongTensor的简单方法,除了以下内容:

sel = torch.nonzero(sel)

返回K×2张量(K是S> = 5的值的数量)。那么我必须将它转换为一维数组,最后允许我索引A:

A:index(1,torch.squeeze(sel:select(2,1)))

这非常麻烦;在例如Matlab所有我必须做的是

A(S>=5,:)

有人能建议更好的方法吗?

2 个答案:

答案 0 :(得分:5)

一种可能的选择是:

sel = S:ge(5):expandAs(A)   -- now you can use this mask with the [] operator
A1 = A[sel]:unfold(1, 2, 2) -- unfold to get back a 2D tensor

示例:

> A = torch.rand(3,2)-1
-0.0047 -0.7976
-0.2653 -0.4582
-0.9713 -0.9660
[torch.DoubleTensor of size 3x2]

> S = torch.Tensor{{6}, {1}, {5}}
 6
 1
 5
[torch.DoubleTensor of size 3x1]

> sel = S:ge(5):expandAs(A)
1  1
0  0
1  1
[torch.ByteTensor of size 3x2]

> A[sel]
-0.0047
-0.7976
-0.9713
-0.9660
[torch.DoubleTensor of size 4]

> A[sel]:unfold(1, 2, 2)
-0.0047 -0.7976
-0.9713 -0.9660
[torch.DoubleTensor of size 2x2]

答案 1 :(得分:0)

有两种更简单的选择:

  1. 使用maskedSelect

    result=A:maskedSelect(your_byte_tensor)

  2. 使用简单的逐元素乘法,例如

    result=torch.cmul(A,S:gt(0))

  3. 如果你需要保持原始矩阵的形状( A),第二个非常有用,例如在backprop中选择层中的神经元。但是,由于在ByteTensor规定的条件不适用时,它会在结果矩阵中放置零,因此您无法使用它来计算乘积(或中位数等)。第一个只返回满足条件的元素,所以这是我用来计算产品或中位数或任何其他我不想要零的东西。

相关问题