如何在pytorch上加载预先标记的数据集

时间:2020-08-18 22:22:50

标签: pytorch h5py

我有一个庞大的数据集,无法存储在内存中,因此我将其预批处理了几个文件,如何使我的数据集和数据加载器类一次加载一个浴池。

  • 所有文件都具有相同的基本名称和唯一的批处理编号
  • 示例文件将称为o3_batch_1.hdf5或o3_batch_2.hdf5。
  • 最大批号为o3_batch_102.hdf5

这是我到目前为止尝试过的:

行得通吗? length将是数据的总长度。

batchNum是文件末尾的非唯一数字。

base是文件共享的通用名称。

类数据(数据集):

# Constructor
def __init__(self, base, batchNum, length):
    name = base + str(batchNum) 
    with h5py.File(name, "r") as f:
        puzz = np.array(f.get('puzzle'))
        sol = np.array(f.get('Sol'))
    self.puzz = torch.from_numpy(puzz)
    self.sol = torch.from_numpy(sol)
    self.len = length
    
# Getter
def __getitem__(self, batchNum, index):    
    return self.puzz[index], self.sol[index]

# Get length
def __len__(self):
    return self.len 

0 个答案:

没有答案