我正在研究一个网络,该网络需要在每次迭代时从数据集加载图像及其地面实况数据。目前,我使用以下策略,每次从我的SSD读取图像及其基本事实。问题是我的图像是* png,大尺寸(512x512x3)
,因此,阅读这些图像需要时间。此外,该策略重复读取数据集中的图像。我们是否有更好的策略来更快地完成它,即。将所有图像加载到RAM并从RAM访问它更快?我正在使用python3和keras
class myDataset(Dataset):
def load_dataset(self, root_path):
#Return a list of image and ground-truth
def load_image(self, image_id):
info = self.image_info[image_id]
image = skimage.io.imread(info["path"])
return image
def load_gt(self, image_id):
#Load ground-truth data
info = self.image_info[image_id]
gt_data = skimage.io.imread(info["path"])
return gt_data
#==========Preparing dataset, return a dict of image, mask path=====
dataset_train = myDataset()
dataset_train.load_dataset('./dataset')
dataset_train.prepare()
#==========Load image and its mask during training=================
for i in range (10000):
#Let image_id be a random id from the dict
image=dataset_train.load_image(image_id)
gt_data=dataset_train.load_gt(image_id)
答案 0 :(得分:1)
如果您能够将所有数据加载到内存中,那么您应该定义这样做。如果你也可以在开始训练之前进行预处理,那就更好了,并且定义为最快的选择。(只是看到,你已经知道了这一点)
如果不是这种情况:在培训之前批处理预处理,并保存使用pickle序列化的数据。加载pickle文件应该快得多。还要考虑使用model.fit_generator而不是多次调用fit方法。您可以编写自己的Generator类,它只加载部分图片。这个类可以从python Sequence类继承。这在pythons功能模型API的fit_generator函数中有记录,是一种非常优雅的方法。