删除火炬张量中的行

时间:2019-10-23 19:55:37

标签: python pytorch torch tensor

我的火炬张量如下-

a = tensor(
[[0.2215, 0.5859, 0.4782, 0.7411],
[0.3078, 0.3854, 0.3981, 0.5200],
[0.1363, 0.4060, 0.2030, 0.4940],
[0.1640, 0.6025, 0.2267, 0.7036],
[0.2445, 0.3032, 0.3300, 0.4253]],  dtype=torch.float64)

如果每行的第一个值小于0.2,则需要删除整行。因此,我需要类似-

的输出
tensor(
[[0.2215, 0.5859, 0.4782, 0.7411],
[0.3078, 0.3854, 0.3981, 0.5200],
[0.2445, 0.3032, 0.3300, 0.4253]],  dtype=torch.float64)

我尝试遍历张量并将有效值附加到新的空张量,但未成功。有什么方法可以有效地获得结果吗?

1 个答案:

答案 0 :(得分:1)

代码

a = torch.Tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
    [0.3078, 0.3854, 0.3981, 0.5200],
    [0.1363, 0.4060, 0.2030, 0.4940],
    [0.1640, 0.6025, 0.2267, 0.7036],
    [0.2445, 0.3032, 0.3300, 0.4253]])

y = a[a[:, 0] > 0.2]
print(y)

输出

tensor([[0.2215, 0.5859, 0.4782, 0.7411],
        [0.3078, 0.3854, 0.3981, 0.5200],
        [0.2445, 0.3032, 0.3300, 0.4253]])