I am using torch.utils.data.DataLoader to load my data. When the data is loaded through the __getitem__ function of my torch.utils.data.Dataset subclass, I store it in a global Python dict variable. But after every training epoch all of that memory is released, and the data is then loaded from the hard disk all over again.
# in data.py file
from torch.utils.data import Dataset, DataLoader

PRELOAD_DATA = dict()

class Data(Dataset):
    def __init__(self, some_path):
        self.file_list = some_path  # a list containing all the file paths

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        if index in PRELOAD_DATA:
            this_blob = PRELOAD_DATA[index]
        else:
            this_blob = load_data_from_disk(self.file_list[index])  # defined elsewhere
            PRELOAD_DATA[index] = this_blob  # I try to cache the data here
        return this_blob

def multithread_dataloader(some_path):
    this_data = Data(some_path)
    this_dataloader = DataLoader(this_data, batch_size=1, shuffle=True,
                                 num_workers=8, drop_last=False)
    return this_dataloader
# in main.py file
from data import multithread_dataloader

file_path_list = ...  # a list containing all the file paths
train_data = multithread_dataloader(file_path_list)

for epoch in range(100):
    for blob in train_data:
        # training
        pass
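
To make this easier to reproduce, here is a minimal self-contained sketch of the same pattern. Everything in it that is not in my code above (the ToyData dataset, the slow_load stand-in for disk I/O, the print statements) is hypothetical and only there for illustration:

# repro.py - minimal sketch; slow_load is a hypothetical stand-in for load_data_from_disk
import os
import time

import torch
from torch.utils.data import Dataset, DataLoader

CACHE = dict()  # plays the role of PRELOAD_DATA

def slow_load(index):
    time.sleep(0.1)        # pretend this is an expensive disk read
    return torch.randn(4)  # pretend this is the loaded blob

class ToyData(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, index):
        if index not in CACHE:
            CACHE[index] = slow_load(index)  # cache miss: load and store
            print(f"cache MISS pid={os.getpid()} index={index} cache_size={len(CACHE)}")
        return CACHE[index]

if __name__ == "__main__":
    loader = DataLoader(ToyData(), batch_size=1, shuffle=True, num_workers=8)
    for epoch in range(3):
        start = time.time()
        for blob in loader:
            pass  # no training, just iterate
        print(f"epoch {epoch} took {time.time() - start:.2f}s")

With num_workers=8 this prints a MISS for every index in every epoch, while with num_workers=0 only the first epoch produces misses. So, as far as I can tell, each worker process gets its own copy of the global dict and the workers are torn down and restarted every epoch, which would explain why my preloaded data disappears. Is there a way to keep the cached data alive across epochs?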