Dataloader won't load images in PyTorch on Google Colab

Time: 2018-11-10 17:45:59

Tags: dataset pytorch google-colaboratory h5py

The code runs fine on my own machine, but on Google Colab it does not load the hdf5 data when it is called by the DataLoader.

Here is my dataset:

import h5py
import numpy as np
from torch.utils.data import Dataset


class SkyDataset(Dataset):
    '''
    This class is a generator that extracts the sunny and cloudy images
    into separate sets.
    The output images are 256x256x3 with values between 0 and 1.
    '''
    def __init__(self, h5_path, load_ratio=1., transform=None):
        super(SkyDataset, self).__init__()
        self.transform = transform
        self.h5_path = h5_path
        self.cloud_args, self.sun_args, whole_data_len = self.__get_dataset_info(h5_path)
        self.data_len = int(load_ratio*whole_data_len)
        print('dataset len: {}'.format(self.data_len))

    def __getitem__(self, index):
        # Get inputs data:
        # inputs: L, r, dt, Vtheta[-1], Vtheta[-2], Vtheta[-2], phi[-1], phi[-2], phi[-3]

        cloud_cam, cloud_arg = self.cloud_args[index]
        sun_cam, sun_arg = self.sun_args[index]
        with h5py.File(self.h5_path, 'r') as file:
            cloud_images = file[cloud_cam]['images'][cloud_arg]/255
            print(cloud_images)
            cloud_labels = file[cloud_cam]['labels'][cloud_arg]
            sun_images = file[sun_cam]['images'][sun_arg]/255
            sun_labels = file[sun_cam]['labels'][sun_arg]
        print('im_dim: {}'.format(cloud_images.shape))

        if self.transform:
            cloud_images = self.transform(cloud_images)
            sun_images = self.transform(sun_images)

        return cloud_images, sun_images, cloud_labels.astype(int), sun_labels.astype(int)

    def __len__(self):
        return self.data_len

    def __get_dataset_info(self, h5_path):
        '''
        :param h5_path:
        :return: iter_args_cloud, iter_args_sun, min_length

        iter_args_cloud: (# cam, # arg) for every cloudy image in the dataset
        iter_args_sun: (# cam, # arg) for every sunny image in the dataset
        min_length: the smaller of the number of sunny and cloudy images
        '''
        with h5py.File(self.h5_path, 'r') as file:
            iter_args_cloud = [(cam, arg[0]) for cam in file for arg in np.argwhere(np.array(file[cam]['labels']) == 0)]
            iter_args_sun = [(cam, arg[0]) for cam in file for arg in np.argwhere(np.array(file[cam]['labels']) == 1)]
        min_length = min(len(iter_args_cloud), len(iter_args_sun))
        return iter_args_cloud, iter_args_sun, min_length
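
Indexing the dataset directly, without any DataLoader, is a quick way to see whether __getitem__ already returns full images. A minimal sketch, assuming the same train_data.hdf5 path used below:

# Sanity check (sketch): index the dataset directly, bypassing the DataLoader,
# to confirm whether __getitem__ returns 256x256x3 arrays.
dataset = SkyDataset('..//data//train_data.hdf5')
cloud, sun, cloud_lbl, sun_lbl = dataset[0]
print(cloud.shape, sun.shape)  # expected: (256, 256, 3) for both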

My dataloader:

import os

import torch
from torchvision import transforms


def getloader(dataset_name, batch_size, data_ratio=1., shuffle=True, num_workers=4, use_gpu=False):

    if dataset_name != 'train' and dataset_name != 'test':
        raise Exception("The 'dataset' parameter must be either 'train' or 'test'")

    if use_gpu:
        pin_memory = True
    else:
        pin_memory = False

    train_path = '..//data//train_data.hdf5'
    test_path = '..//data//test_data.hdf5'
    print('exist: {}'.format(os.path.exists(train_path)))

    if dataset_name == 'train':
        path = train_path
    elif dataset_name == 'test':
        path = test_path

    dataset = SkyDataset(path, load_ratio=data_ratio)
    print(' - {} (merged) dataset loaded - '.format(dataset_name))

    transforms_ = transforms.Compose([RandomHorizontalFlip(),
                                      ToTensor(),
                                      transforms.Normalize((0.515, 0.525, 0.532), (0.22, 0.23, 0.265))])

    dataset.transform = transforms_

    # The batch_size for test and val only matters for performance during inference on the GPU
    params = {'batch_size': batch_size, 'shuffle': shuffle, 'pin_memory': pin_memory, 'num_workers': num_workers}
    loader = torch.utils.data.DataLoader(dataset, **params)

    return loader
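
While debugging on Colab, the loader can also be built with num_workers=0 so that samples are loaded in the main process and exceptions from __getitem__ keep their original traceback; a small sketch using the getloader above:

# Debugging sketch: num_workers=0 loads samples in the main process, so errors
# from __getitem__ are raised directly instead of being re-raised by a worker.
loader = getloader('train', batch_size=4, num_workers=0, use_gpu=False)
cloud, sun, cloud_lbl, sun_lbl = next(iter(loader))
print(cloud.shape, sun.shape)  # expected: torch.Size([4, 3, 256, 256]) for both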

import numpy as np
import torch


class ToTensor(object):

    def __call__(self, image):
        # (H, W, C) numpy array -> channels-first torch tensor
        # (note that transpose((2, 1, 0)) also swaps height and width)
        image = image.transpose((2, 1, 0))
        return torch.from_numpy(image.copy())


class RandomHorizontalFlip():

    def __call__(self, image):
        # Flip the image along its width axis (axis=1) half of the time
        if np.random.rand() > 0.5:
            image = np.flip(image, axis=1)
        return image
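
These transforms all assume an H x W x 3 numpy image: np.flip needs axis=1 to exist, the transpose((2,1,0)) in ToTensor needs three dimensions, and Normalize expects a channels-first tensor. A quick sketch with a dummy array of the right shape runs fine, whereas a 0-dimensional input fails in np.flip just like the traceback below:

import numpy as np
from torchvision import transforms

# Sketch: the pipeline only works on a 3-D HxWxC array.
pipeline = transforms.Compose([RandomHorizontalFlip(),
                               ToTensor(),
                               transforms.Normalize((0.515, 0.525, 0.532), (0.22, 0.23, 0.265))])

dummy = np.random.rand(256, 256, 3)
print(pipeline(dummy).shape)   # torch.Size([3, 256, 256])
# pipeline(np.float64(0.93))   # fails: a 0-d input has no axis=1 to flip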

I get this error:

0.9372549019607843
im_dim: ()
0.9254901960784314
im_dim: ()
0.8235294117647058
im_dim: ()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-9-96cc2f68b8b5> in <module>()
    103 
    104 if __name__ == '__main__':
--> 105     main()

<ipython-input-9-96cc2f68b8b5> in main()
     98     # Appel de l'entrainement
     99     nets = train(nets, optimizers, criterions, train_loader, val_split=val_split, lambda_cyc=lambda_cyc,
--> 100                  n_epoch=1, use_gpu=use_gpu, schedulers=lr_schedulers)
    101 
    102     #test(nets, test_loader, use_gpu=use_gpu)

/content/projet_ml/CycleGAN/utils/trainning.py in train(nets, optimizers, criterions, train_loader, val_split, lambda_cyc, n_epoch, use_gpu, schedulers)
     27     for epoch in range(n_epoch):
     28         start = time.time()
---> 29         do_epoch(nets, optimizers, criterions, train_loader, val_split, lambda_cyc, use_gpu, schedulers)
     30         end = time.time()
     31 

/content/projet_ml/CycleGAN/utils/trainning.py in do_epoch(nets, optimizers, criterions, train_loader, val_split, lambda_cyc, use_gpu, schedulers)
     50 
     51     train_loader_length = len(train_loader)
---> 52     for idx, data in enumerate(train_loader):
     53         if idx > train_loader_length * (1 - val_split):
     54             # Le reste des données appartiennent à validation

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    334                 self.reorder_dict[idx] = batch
    335                 continue
--> 336             return self._process_next_batch(batch)
    337 
    338     next = __next__  # Python 2 compatibility

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    355         self._put_indices()
    356         if isinstance(batch, ExceptionWrapper):
--> 357             raise batch.exc_type(batch.exc_msg)
    358         return batch
    359 

ValueError: Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/numpy/lib/function_base.py", line 206, in flip
    indexer[axis] = slice(None, None, -1)
IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/content/projet_ml/CycleGAN/utils/Dataset.py", line 35, in __getitem__
    cloud_images = self.transform(cloud_images)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
  File "/content/projet_ml/CycleGAN/utils/Loaders.py", line 66, in __call__
    image = np.flip(image, axis=1)
  File "/usr/local/lib/python3.6/dist-packages/numpy/lib/function_base.py", line 209, in flip
    % (axis, m.ndim))
ValueError: axis=1 is invalid for the 0-dimensional input array

Of course I get the ValueError because my images are empty (see the im_dim: () output). It looks like __getitem__ is not doing its job, but I can't figure out anything more than that. I also tried without the h5py context manager (calling file.close() myself). If anyone can help me, I would be very grateful!
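
For reference, inspecting the file as it exists on Colab would show whether the images are actually stored with their full shape; a short check, assuming the same layout as above (one group per camera with 'images' and 'labels' datasets):

import h5py

# Check the shapes stored in the hdf5 file on Colab: each camera group should
# hold image samples of shape 256x256x3; anything flatter would explain the
# 0-dimensional arrays coming out of __getitem__.
with h5py.File('..//data//train_data.hdf5', 'r') as file:
    for cam in file:
        print(cam, file[cam]['images'].shape, file[cam]['labels'].shape)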

Thanks!

0 Answers:

No answers yet