索引两次后如何在Pytorch中更新张量?

时间:2019-04-09 03:39:26

标签: python pytorch

我知道像这样索引张量后如何更新张量:

import torch

b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b] = 2
b
# tensor([0, 2, 0, 2], dtype=torch.uint8)

但是在索引两次之后,有什么方法可以更新原始张量吗?例如

i = 1
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b][i] = 2
b
# tensor([0, 1, 0, 1], dtype=torch.uint8)

我想让b最终成为tensor([0, 1, 0, 2])。有办法吗?

我知道我能做到

masked = b[b]
masked[i] = 2
b[b] = masked
b
# tensor([0, 1, 0, 2], dtype=torch.uint8)

但是还有什么更好的方法吗?看来这一定是低效的;如果masked很大,那么我实际上只更改了一个位置,就更新了b中的许多位置。

(如果使用不同于两次索引的方法更好的方法,我遇到的一般问题是如何在该张量的被掩盖版本的第i个位置处更改原始张量中的值。)< / p>

1 个答案:

答案 0 :(得分:1)

我采用了here中的另一个解决方案,并将其与您的解决方案进行了比较:

解决方案:

b[b.nonzero()[i]] = 2

运行时比较:

import torch as t
import numpy as np
import timeit


if __name__ == "__main__":

    np.random.seed(12345)
    b = t.tensor(np.random.randint(0,2, [1000]), dtype=t.uint8)
    # inconvenient way to think of a random index halfway that is 1.
    halfway = np.array(list(range(len(b))))[b == 1][len(b[b == 1]) //2]

    runs = 100000

    elapsed1 = timeit.timeit("mask=b[b]; mask[halfway] = 2; b[b] = mask", 
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (original): {:.6f} ms per call".format(elapsed1 / runs))

    elapsed2 = timeit.timeit("b[b.nonzero()[halfway]]=2",
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (improved): {:.6f} ms per call".format(elapsed2 / runs))

结果:

Time taken (original): 0.000096 ms per call
Time taken (improved): 0.000047 ms per call

长度为100000的向量的结果

Time taken: 0.010284 ms per call
Time taken: 0.003667 ms per call

因此,解决方案仅相差2倍。我不确定这是否是最佳解决方案,但是取决于您的大小(以及调用该函数的频率),它应该使您大致了解自己的情况看着。