使用torch.utils.data.DataLoader时如何将加载的数据保留在RAM中?

时间:2019-04-25 08:10:15

标签: python pytorch

我正在使用torch.utils.data.DataLoader加载我的数据。当通过torch.utils.data.Dataset对象中的__getitem__函数加载数据时,我将数据存储在一个全局python dict变量中。但是在每个epoch训练之后,它释放了所有的内存,然后再次从硬盘加载了所有数据。

# in data.py file
from torch.utils.data import Dataset, DataLoader

# Module-level, per-process cache of loaded samples, keyed by file path.
# NOTE(review): with DataLoader(num_workers>0) each worker process gets its
# OWN copy of this dict, so entries cached in one worker are invisible to
# the other workers and to the parent process.
PRELOAD_DATA = {}

class Data(Dataset):
    """Dataset that lazily loads samples from disk and memoizes them.

    Each sample is loaded at most once per process; later accesses for the
    same file return the cached object from ``PRELOAD_DATA``.
    """

    def __init__(self, some_path):
        # some_path: a list containing one file path per sample.
        self.file_list = some_path

    def __len__(self):
        # Number of samples == number of file paths.
        return len(self.file_list)

    def __getitem__(self, index):
        # Key the cache by file path rather than by integer index: two Data
        # instances built over different file lists would otherwise collide
        # on the same index and return each other's samples.
        path = self.file_list[index]
        try:
            return PRELOAD_DATA[path]
        except KeyError:
            # load_data_from_disk is presumably defined elsewhere in this
            # module -- TODO confirm.
            this_blob = load_data_from_disk(path)
            PRELOAD_DATA[path] = this_blob
            return this_blob

def multithread_dataloader(some_path, num_workers=8):
    """Build a shuffling DataLoader over the files in ``some_path``.

    Args:
        some_path: list of file paths handed to ``Data``.
        num_workers: number of worker processes (default 8, as before).

    Returns:
        A torch.utils.data.DataLoader yielding batches of size 1.
    """
    this_data = Data(some_path)
    # persistent_workers keeps the worker processes -- and anything they have
    # cached in RAM, such as PRELOAD_DATA -- alive across epochs. By default
    # DataLoader tears workers down after every epoch, which discards any
    # per-worker cache and forces a full reload from disk each epoch.
    # Requires torch >= 1.7; persistent_workers must be False when
    # num_workers == 0.
    return DataLoader(
        this_data,
        batch_size=1,
        shuffle=True,
        num_workers=num_workers,
        drop_last=False,
        persistent_workers=num_workers > 0,
    )
# in main.py file
from data import multithread_dataloader

file_path_list = ...  # a list containing all the file paths
train_data = multithread_dataloader(file_path_list)

for epoch in range(100):
    for blob in train_data:
        # training step on `blob` goes here
        pass
0 个答案:

没有答案