如何迭代 PyTorch 张量

时间:2021-05-19 01:59:19

标签: pytorch iteration tensor

我有一个大小为 (1000,110) 的张量数据,我想遍历张量的第一个索引并计算以下内容。

    data = torch.randn(size=(1000,110)).to(device)
    
    male_poor = torch.tensor(0).float().to(device)
    male_rich = torch.tensor(0).float().to(device)
    
    female_poor = torch.tensor(0).float().to(device)
    female_rich = torch.tensor(0).float().to(device)
    
    for i in data:
    
        if torch.argmax(i[64:66]) == 0 and torch.argmax(i[108:110]) == 0:
          female_poor += 1
        if torch.argmax(i[64:66]) == 0 and torch.argmax(i[108:110]) == 1:
          female_rich += 1
        if torch.argmax(i[64:66]) == 1 and torch.argmax(i[108:110]) == 0:
          male_poor += 1
        if torch.argmax(i[64:66]) == 1 and torch.argmax(i[108:110]) == 1:
          male_rich += 1


    disparity = ((female_rich/(female_rich + female_poor))) / ((male_rich/(male_rich + male_poor)))

有没有比 for 循环更快的方法来做到这一点?

1 个答案:

答案 0 :(得分:2)

pytorch(以及 numpy)的关键是矢量化,也就是说,如果您可以通过对矩阵进行操作来删除循环,速度会快很多。与底层编译的 C 代码中的循环相比,python 中的循环相当慢。在我的机器上,您的代码的执行时间大约为 0.091 秒,以下矢量化代码大约为 0.002 秒,因此大约快了 50 倍:

import torch
torch.manual_seed(0)
device = torch.device('cpu')

data = torch.randn(size=(1000, 110)).to(device)

import time
t = time.time()
#vectorize over first dimension
argmax64_0 = torch.argmax(data[:, 64:66], dim=1) == 0
argmax64_1 = torch.argmax(data[:, 64:66], dim=1) == 1
argmax108_0 = torch.argmax(data[:, 108:110], dim=1) == 0
argmax108_1 = torch.argmax(data[:, 108:110], dim=1) == 1
female_poor = (argmax64_0 & argmax108_0).sum()
female_rich = (argmax64_0 & argmax108_1).sum()
male_poor = (argmax64_1 & argmax108_0).sum()
male_rich = (argmax64_1 & argmax108_1).sum()

disparity = ((female_rich / (female_rich + female_poor))) / ((male_rich / (male_rich + male_poor)))

print(time.time()-t)
print(disparity)
相关问题