我有一个包含大量图像的大型数据集,我发现数据加载器(DataLoader)的速度非常慢。我做了很多测试,发现问题与图片数量很大有关:
这是示例代码:
class MyTestDataset(torchvision.datasets.vision.VisionDataset):
    """Synthetic stand-in for a very large image dataset.

    Every lookup sleeps for one millisecond and then returns freshly
    generated random data; the requested index is ignored. The reported
    length is fixed at 17,250,000 items.
    """

    def __init__(self, transform=None, target_transform=None):
        """Initialise the base ``VisionDataset`` with ``None`` as its root.

        Args:
            transform: Forwarded unchanged to ``VisionDataset``.
            target_transform: Forwarded unchanged to ``VisionDataset``.
        """
        super().__init__(
            None,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return a ``(image, label)`` pair of random values.

        Args:
            index: Position requested by the caller; not used.

        Returns:
            A tuple of a ``(256, 256)`` array from ``np.random.rand`` and an
            integer drawn by ``random.randint(0, 300)``.
        """
        # Fixed 1 ms delay on every item access.
        time.sleep(0.001)
        image = np.random.rand(256, 256)
        label = random.randint(0, 300)
        return image, label

    def __len__(self):
        """Return the fixed number of items the dataset claims to hold."""
        return 17_250_000
# Number of samples per batch; also the number of items read for the baseline.
BATCH_SIZE = 128


def time_first_batch(dataset, shuffle, num_workers):
    """Measure how long it takes to obtain the first batch from a DataLoader.

    The loader is constructed before the clock starts, so only creating the
    iterator and fetching one batch is timed.

    Args:
        dataset: Dataset to wrap in a ``torch.utils.data.DataLoader``.
        shuffle: Passed to ``DataLoader`` as ``shuffle``.
        num_workers: Passed to ``DataLoader`` as ``num_workers``.

    Returns:
        Elapsed time in seconds as a float.
    """
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    next(iter(loader))
    return time.perf_counter() - start


# The guard is needed because some runs use num_workers > 0: on platforms
# where multiprocessing uses the "spawn" start method (e.g. Windows, macOS)
# each worker re-imports this module, and unguarded top-level code would be
# executed again inside every worker.
if __name__ == "__main__":
    dataset = MyTestDataset()

    # Baseline: read BATCH_SIZE items directly from the dataset, no DataLoader.
    start = time.perf_counter()
    imgs = [dataset[10000 + idx * 1000][0] for idx in range(BATCH_SIZE)]
    base = time.perf_counter() - start
    print("next dataset", base)

    # NOTE(review): the extra cost seen with shuffle=True is presumably the
    # random sampler generating a permutation of all len(dataset) indices
    # when iteration starts -- confirm against torch.utils.data.RandomSampler.
    configurations = (
        ("loader no shuffle num_workers-0", False, 0),
        ("loader no shuffle num_workers-8", False, 8),
        ("loader shuffle num_workers-0", True, 0),
        ("loader shuffle num_workers-8", True, 8),
    )
    for label, shuffle, num_workers in configurations:
        cost = time_first_batch(dataset, shuffle, num_workers)
        # Print absolute time and the slowdown relative to the baseline.
        print(label, cost, cost / base)
下面是输出
next dataset 0.2839939594268799
loader no shuffle num_workers-0 0.3133578300476074 1.103396109832709
loader no shuffle num_workers-8 0.811976432800293 2.859132759157693
loader shuffle num_workers-0 2.041795253753662 7.189572827091643
loader shuffle num_workers-8 2.617912769317627 9.218198776483705
似乎开启随机打乱(shuffle=True)会使加载速度慢大约 9 倍?