如何根据pytorch中另一个张量的值将张量的某个值更改为零?

时间:2020-05-25 12:23:22

标签: python pytorch

我有两个张量:张量a和张量b。如何根据张量b的值更改张量a的某些值?

我知道以下代码是正确的,但是当张量很大时,它的运行速度很慢。还有其他方法吗?

import torch
a = torch.rand(10).cuda()
b = torch.rand(10).cuda()
a[b > 0.5] = 0.

2 个答案:

答案 0 :(得分:1)

我想/bin/sh会更快,这是我在CPU中进行的测量。

torch.where
import torch
a = torch.rand(3**10)
b = torch.rand(3**10)
%timeit a[b > 0.5] = 0.
852 µs ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

答案 1 :(得分:1)

对于这个确切的用例,请考虑

a * (b <= 0.5)

这似乎是以下最快的

In [1]: import torch
   ...: a = torch.rand(3**10)
   ...: b = torch.rand(3**10)

In [2]: %timeit a[b > 0.5] = 0.
553 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [3]: a = torch.rand(3**10)

In [4]: %timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)
   ...:
49 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [5]: a = torch.rand(3**10)

In [6]: %timeit temp = (a * (b <= 0.5))
44 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]: %timeit a.masked_fill_(b > 0.5, 0.)
244 µs ± 3.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)