InvalidArgumentError:图像分割中,第3维的切片索引1超出范围

时间:2019-11-19 08:34:06

标签: python python-3.x tensorflow keras image-segmentation

我正在从下面的位置获取Unet代码

https://github.com/JooHyun-Lee/BraTs/blob/master/keras/unet.py

# Run configuration: data locations, the two paired image/label folders,
# and training hyper-parameters consumed by dataset() and the fit call below.
args = {
    'image_root': "/media/kriti/Data_Drive/brain_mri_work/BRATS2017_preprocessed/HGG/data",
    'label_root': "/media/kriti/Data_Drive/brain_mri_work/BRATS2017_preprocessed/HGG/data",
    'image_folder1': "T1a",
    'label_folder1': "label1",
    'image_folder2': "T1b",
    'label_folder2': "label2",
    'epoch': 10,           # number of training epochs
    'lr': 0.0001,          # learning rate
    'batch_size': 4,       # samples per generator batch
    'ckpt_path': 'unet.hdf5',  # checkpoint file path
}

import numpy as np

from keras.preprocessing.image import ImageDataGenerator

def dataset(args, mode='train',
            image_color_mode = "grayscale", label_color_mode = "grayscale",
            image_save_prefix  = "image", label_save_prefix  = "label",
            save_to_dir = None, target_size = (256,256), seed = 1):

    ''' Prepare dataset ( pre-processing + augmentation(optional) )

    Yields synchronized (image, label) batches. In 'train' mode, batches are
    drawn at random from two folder pairs (folder1 with probability 1/3,
    folder2 with probability 2/3) with augmentation; in 'valid'/'test' mode
    only folder pair 1 is used, unshuffled and unaugmented.

    Args:
        args (dict):              Configuration dict ('image_root', 'label_root',
                                  'image_folder1/2', 'label_folder1/2', 'batch_size')
        mode (str):               Mode ('train', 'valid', 'test')
        image_color_mode (str):   Image color mode flag
        label_color_mode (str):   Label color mode flag
        image_save_prefix (str):  Prefix to use for filenames of saved images
        label_save_prefix (str):  Prefix to use for filenames of saved labels
        save_to_dir (str):        Save directory (None disables saving)
        target_size (tuple):      Target size (height, width)
        seed (int):               Seed shared by image/label flows so that
                                  augmentation stays identical for both
    Yields:
        (numpy.ndarray, numpy.ndarray): image batch and label batch
    Raises:
        ValueError: if mode is not 'train', 'valid' or 'test'
    '''

    # Data augmentation: identical parameters for images and labels so that
    # (with a shared seed) each label stays aligned with its image.
    if mode == 'train':
        shuffle = True
        aug_params = dict(rotation_range=20,
                          horizontal_flip=True,
                          vertical_flip=True,
                          width_shift_range=0.1,
                          height_shift_range=0.1,
                          shear_range=0.2,
                          zoom_range=0.1)
                          # brightness_range=[0.8,1.2] intentionally disabled
        image_datagen = ImageDataGenerator(**aug_params)
        label_datagen = ImageDataGenerator(**aug_params)
    elif mode == 'test' or mode == 'valid':
        shuffle = False
        image_datagen = ImageDataGenerator()
        label_datagen = ImageDataGenerator()
    else:
        raise ValueError('dataset mode ERROR!')

    def _paired_flow(image_folder, label_folder, flow_seed):
        # Build one synchronized (image, label) generator pair; the shared
        # seed keeps shuffling/augmentation identical across the two flows.
        image_gen = image_datagen.flow_from_directory(
            args['image_root'],
            classes = [image_folder],
            class_mode = None,
            color_mode = image_color_mode,
            target_size = target_size,
            batch_size = args['batch_size'],
            save_to_dir = save_to_dir,
            save_prefix  = image_save_prefix,
            shuffle = shuffle,
            seed = flow_seed)
        label_gen = label_datagen.flow_from_directory(
            args['label_root'],
            classes = [label_folder],
            class_mode = None,
            color_mode = label_color_mode,
            target_size = target_size,
            batch_size = args['batch_size'],
            save_to_dir = save_to_dir,
            save_prefix  = label_save_prefix,
            shuffle = shuffle,
            seed = flow_seed)
        return zip(image_gen, label_gen)

    data_generator1 = _paired_flow(args['image_folder1'], args['label_folder1'], seed)

    if mode == 'test' or mode == 'valid':
        # Validation/test: stream folder pair 1 only (the zip of two
        # Keras directory iterators is itself infinite).
        for img, label in data_generator1:
            yield (img, label)
    else:
        data_generator2 = _paired_flow(args['image_folder2'], args['label_folder2'], seed + 1)

        # BUG FIX: the original iterated `for (img, label) in data_generatorN`
        # inside `while True`; since each zip is infinite, the inner loop never
        # exited and np.random.randint was evaluated exactly once, so a single
        # source fed the whole run. Draw one batch per random choice instead.
        while True:
            if np.random.randint(3) == 2:
                yield next(data_generator1)
            else:
                yield next(data_generator2)

# Build the infinite training and validation batch generators.
trainset = dataset(args, mode='train')
validset = dataset(args, mode='valid')

# Observed batch shapes (batch_size=4, grayscale 256x256):
# next(validset)[0].shape --> (4,256,256,1)
# next(validset)[1].shape --> (4,256,256,1)


# NOTE(review): `unet` is not defined in this file — presumably imported from
# the linked BraTs repo (keras/unet.py). The reported "slice index 1 of
# dimension 3 out of bounds" error suggests the loss indexes a second channel
# while the labels above have only 1 channel — confirm the model's final
# layer/loss match single-channel binary segmentation.
model = unet()

# NOTE(review): `shuffle=True` has no effect with a plain Python generator;
# validation_steps=2000 at batch_size=4 implies 8000 validation samples —
# verify against the actual dataset size.
model.fit_generator(trainset, steps_per_epoch=500, shuffle=True, epochs=args['epoch'],validation_data=validset, validation_steps=2000)

使用fit_generator时出现此错误。手动堆叠X和y而不是使用fit_generator时,我会遇到相同的错误。

目标是二元类分割,其中X和y均为灰度图像。错误的回溯看起来像这样

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-25-b1490d73b6ae> in <module>
----> 1 model.fit_generator(trainset, steps_per_epoch=500, shuffle=True, epochs=args['epoch'],validation_data=validset, validation_steps=2000)

~/anaconda3/lib/python3.7/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/anaconda3/lib/python3.7/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1416             use_multiprocessing=use_multiprocessing,
   1417             shuffle=shuffle,
-> 1418             initial_epoch=initial_epoch)
   1419 
   1420     @interfaces.legacy_generator_methods_support

~/anaconda3/lib/python3.7/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    215                 outs = model.train_on_batch(x, y,
    216                                             sample_weight=sample_weight,
--> 217                                             class_weight=class_weight)
    218 
    219                 outs = to_list(outs)

~/anaconda3/lib/python3.7/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1215             ins = x + y + sample_weights
   1216         self._make_train_function()
-> 1217         outputs = self.train_function(ins)
   1218         return unpack_singleton(outputs)
   1219 

~/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

~/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: slice index 1 of dimension 3 out of bounds.
     [[{{node loss/activation_1_loss/strided_slice_2}}]]

我还需要知道ImageDataGenerator的训练和验证目录应该如何组织,因为这不是分类问题。

0 个答案:

没有答案