该代码在我的操作系统上运行正常,但是对于Google Colab,在由数据加载器调用时,它不会加载hdf5数据。
这是我的数据集:
class SkyDataset(Dataset):
'''
Cette classe est un générateur qui extrait les images ensoleillées et
nuageuses dans des ensembles disctints.
Les images en sortie font 256X256X3 et ont des valeurs entre 0 et 1.
'''
def __init__(self, h5_path, load_ratio=1., transform=None):
super(SkyDataset, self).__init__()
self.transform = transform
self.h5_path = h5_path
self.cloud_args, self.sun_args, whole_data_len = self.__get_dataset_info(h5_path)
self.data_len = int(load_ratio*whole_data_len)
print('dataset len: {}'.format(self.data_len))
def __getitem__(self, index):
# Get inputs data:
# inputs: L, r, dt, Vtheta[-1], Vtheta[-2], Vtheta[-2], phi[-1], phi[-2], phi[-3]
cloud_cam, cloud_arg = self.cloud_args[index]
sun_cam, sun_arg = self.sun_args[index]
with h5py.File(self.h5_path, 'r') as file:
cloud_images = file[cloud_cam]['images'][cloud_arg]/255
print(cloud_images)
cloud_labels = file[cloud_cam]['labels'][cloud_arg]
sun_images = file[sun_cam]['images'][sun_arg]/255
sun_labels = file[sun_cam]['labels'][sun_arg]
print('im_dim: {}'.format(cloud_images.shape))
if self.transform:
cloud_images = self.transform(cloud_images)
sun_images = self.transform(sun_images)
return cloud_images, sun_images, cloud_labels.astype(int), sun_labels.astype(int)
def __len__(self):
return self.data_len
def __get_dataset_info(self, h5_path):
'''
:param h5_path:
:return: iter_args_cloud, iter_args_sun, min_length
iter_args_cloud: (# cam, # argument) pour toutes les images nuageuses dans dataset
iter_args_sun: (# cam, # argument) pour toutes les images ensoleillées dans dataset
min_length: Le nombre minimal entre les images ensoleillées et nuageuses
'''
with h5py.File(self.h5_path, 'r') as file:
iter_args_cloud = [(cam, arg[0]) for cam in file for arg in np.argwhere(np.array(file[cam]['labels']) == 0)]
iter_args_sun = [(cam, arg[0]) for cam in file for arg in np.argwhere(np.array(file[cam]['labels']) == 1)]
min_length = min(len(iter_args_cloud), len(iter_args_sun))
return iter_args_cloud, iter_args_sun, min_length
我的数据加载器:
def getloader(dataset_name, batch_size, data_ratio=1., shuffle=True, num_workers=4, use_gpu=False):
if dataset_name != 'train' and dataset_name != 'test':
raise Exception("Le paramètre 'dataset' doit prendre la valeur 'train' ou 'test'")
if use_gpu:
pin_memory = True
else:
pin_memory = False
train_path = '..//data//train_data.hdf5'
test_path = '..//data//test_data.hdf5'
print('exist: {}'.format(os.path.exists(train_path)))
if dataset_name == 'train':
path = train_path
elif dataset_name == 'test':
path = test_path
dataset = SkyDataset(path, load_ratio=data_ratio)
print(' - {} (merged) dataset loaded - '.format(dataset_name))
transforms_ = transforms.Compose([RandomHorizontalFlip(),
ToTensor(),
transforms.Normalize((0.515, 0.525, 0.532), (0.22, 0.23, 0.265))])
dataset.transform = transforms_
# Les batch_size sur test et val ne servent qu'à la performance lors de l'inférence sur GPU
params = {'batch_size': batch_size, 'shuffle': shuffle, 'pin_memory': pin_memory, 'num_workers': num_workers}
loader = torch.utils.data.DataLoader(dataset, **params)
return loader
class ToTensor(object):
def __call__(self, image):
image = image.transpose((2,1,0))
return torch.from_numpy(image.copy())
class RandomHorizontalFlip():
def __call__(self, image):
if np.random.rand() > 0.5:
image = np.flip(image, axis=1)
else:
pass
return image
我得到这个错误:
0.9372549019607843
im_dim: ()
0.9254901960784314
im_dim: ()
0.8235294117647058
im_dim: ()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-9-96cc2f68b8b5> in <module>()
103
104 if __name__ == '__main__':
--> 105 main()
<ipython-input-9-96cc2f68b8b5> in main()
98 # Appel de l'entrainement
99 nets = train(nets, optimizers, criterions, train_loader,
val_split=val_split, lambda_cyc=lambda_cyc,
--> 100 n_epoch=1, use_gpu=use_gpu,
schedulers=lr_schedulers)
101
102 #test(nets, test_loader, use_gpu=use_gpu)
/content/projet_ml/CycleGAN/utils/trainning.py in train(nets, optimizers,
criterions, train_loader, val_split, lambda_cyc, n_epoch, use_gpu,
schedulers)
27 for epoch in range(n_epoch):
28 start = time.time()
---> 29 do_epoch(nets, optimizers, criterions, train_loader,
val_split, lambda_cyc, use_gpu, schedulers)
30 end = time.time()
31
/content/projet_ml/CycleGAN/utils/trainning.py in do_epoch(nets, optimizers,
criterions, train_loader, val_split, lambda_cyc, use_gpu, schedulers)
50
51 train_loader_length = len(train_loader)
---> 52 for idx, data in enumerate(train_loader):
53 if idx > train_loader_length * (1 - val_split):
54 # Le reste des données appartiennent à validation
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in
__next__(self)
334 self.reorder_dict[idx] = batch
335 continue
--> 336 return self._process_next_batch(batch)
337
338 next = __next__ # Python 2 compatibility
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in
_process_next_batch(self, batch)
355 self._put_indices()
356 if isinstance(batch, ExceptionWrapper):
--> 357 raise batch.exc_type(batch.exc_msg)
358 return batch
359
ValueError: Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/numpy/lib/function_base.py",
line 206, in flip
indexer[axis] = slice(None, None, -1)
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-
packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python3.6/dist-
packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/content/projet_ml/CycleGAN/utils/Dataset.py", line 35, in
__getitem__
cloud_images = self.transform(cloud_images)
File "/usr/local/lib/python3.6/dist-
packages/torchvision/transforms/transforms.py", line 49, in __call__
img = t(img)
File "/content/projet_ml/CycleGAN/utils/Loaders.py", line 66, in __call__
image = np.flip(image, axis=1)
File "/usr/local/lib/python3.6/dist-packages/numpy/lib/function_base.py",
line 209, in flip
% (axis, m.ndim))
ValueError: axis=1 is invalid for the 0-dimensional input array
当然,由于我的图片是空的,所以我得到了ValueError(请看im_dim:())。好像__getitem__不能完成工作,但我做不到的更多。我也尝试了没有h5py上下文管理器(file.close())。如果有人可以帮助我,我将不胜感激!
谢谢!