我知道像这样索引张量后如何更新张量:
import torch
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b] = 2
b
# tensor([0, 2, 0, 2], dtype=torch.uint8)
但是在索引两次之后,有什么方法可以更新原始张量吗?例如
i = 1
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b][i] = 2
b
# tensor([0, 1, 0, 1], dtype=torch.uint8)
我想让b
最终成为tensor([0, 1, 0, 2])
。有办法吗?
我知道我能做到
masked = b[b]
masked[i] = 2
b[b] = masked
b
# tensor([0, 1, 0, 2], dtype=torch.uint8)
但是还有什么更好的方法吗?看来这一定是低效的;如果masked
很大,那么我实际上只更改了一个位置,就更新了b
中的许多位置。
(如果使用不同于两次索引的方法更好的方法,我遇到的一般问题是如何在该张量的被掩盖版本的第i
个位置处更改原始张量中的值。)< / p>
答案 0 :(得分:1)
我采用了here中的另一个解决方案,并将其与您的解决方案进行了比较:
解决方案:
b[b.nonzero()[i]] = 2
运行时比较:
import torch as t
import numpy as np
import timeit
if __name__ == "__main__":
np.random.seed(12345)
b = t.tensor(np.random.randint(0,2, [1000]), dtype=t.uint8)
# inconvenient way to think of a random index halfway that is 1.
halfway = np.array(list(range(len(b))))[b == 1][len(b[b == 1]) //2]
runs = 100000
elapsed1 = timeit.timeit("mask=b[b]; mask[halfway] = 2; b[b] = mask",
"from __main__ import b, halfway", number=runs)
print("Time taken (original): {:.6f} ms per call".format(elapsed1 / runs))
elapsed2 = timeit.timeit("b[b.nonzero()[halfway]]=2",
"from __main__ import b, halfway", number=runs)
print("Time taken (improved): {:.6f} ms per call".format(elapsed2 / runs))
结果:
Time taken (original): 0.000096 ms per call
Time taken (improved): 0.000047 ms per call
长度为100000
的向量的结果
Time taken: 0.010284 ms per call
Time taken: 0.003667 ms per call
因此,解决方案仅相差2倍。我不确定这是否是最佳解决方案,但是取决于您的大小(以及调用该函数的频率),它应该使您大致了解自己的情况看着。