如何使用PyTorch(或Keras Tensorflow)加载预先批处理的数据?

时间:2019-12-04 11:51:28

标签: python keras conv-neural-network pytorch sequence

我目前正在开发一个卷积神经网络,该网络可以摄取图像并对其进行回归分析。示例训练数据集可能包含40,000张图像,而验证数据集将包含20,000张图像。

为避免将它们全部加载到内存中并遇到OOM问题,我将数据预先批处理成500个图像文件的批处理,每个文件均采用.h5文件格式-生成了80个.h5文件用于训练,40个用于验证。培训和验证文件保存在其自己的目录中,例如数据\培训和数据\验证。

https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel中推荐的自定义数据加载器方法如下:

class Dataset(data.Dataset):
    'Characterises a dataset for PyTorch'
     def __init__(self, list_IDs, labels):
        'Initialisation'
         self.labels = labels # targets
         self.list_IDs = list_IDs # h5 files

    def __len__(self):
        'Denotes number of samples'
        return len(self.list_IDs)*500

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        X = torch.load('data/' + ID + '.h5')
        y = self.labels[ID]

        # Load data and get label
        return X, y

我的问题是,如何修改上面的内容以遍历每个相关的.h5文件,加载500个图像块,然后从该块向神经网络提供batch_size数量的图像继续训练。例如,将500张图片的图片加载到内存中,从该图片中选择两张图片(如果为batch_size = 2,然后通过生成器的形式将它们传递到PyTorch工作流程甚至keras fit_generator函数中(如我我在两个框架中都尝试过吗?)。

我只是不太了解它:我是否需要在两个单独的序列上使用__getitem__方法索引?一个是包含图像的.h5文件,第二个是样本总数(即40,000)。我曾短暂考虑过在方法中循环以遍历.h5文件,但我认为这会中断生成器对方法的调用。

到目前为止,我已经尝试并且有效的方法如下:

def generate_batches_from_h5_file(files, batchsize, targets):
    while True:
        with h5py.File(file, 'r') as f:
            filesize = len(f['images']) # load the 500 images
            n_entries = 0

            while n_entries < (filesize - batchsize):
                xs = np.array(f['images'][n_entries : n_entries + batchsize])
                IDs = np.array(f['IDs'][n_entries : n_entries + batchsize])
                values = [targets[ID] for ID in IDs]
                ys = np.vstack(values)
                n_entries += batchsize

                yield (xs, ys)

但是,这不是线程安全的,如果我要启用多处理功能,将无法正常工作。

对此,我将不胜感激,谢谢。

0 个答案:

没有答案