我使用的 U-Net 代码来自以下地址:
https://github.com/JooHyun-Lee/BraTs/blob/master/keras/unet.py
# Training configuration, consumed through dict lookups (args['...']).
# dataset() reads images from <image_root>/<image_folderN> and masks from
# <label_root>/<label_folderN>. Source 1 (folder1) is used in every mode;
# source 2 (folder2) is used only for training.
# 'lr' and 'ckpt_path' are not read by the code in this snippet.
args = dict(
    image_root="/media/kriti/Data_Drive/brain_mri_work/BRATS2017_preprocessed/HGG/data",
    label_root="/media/kriti/Data_Drive/brain_mri_work/BRATS2017_preprocessed/HGG/data",
    image_folder1="T1a",
    label_folder1="label1",
    image_folder2="T1b",
    label_folder2="label2",
    epoch=10,
    lr=0.0001,
    batch_size=4,
    ckpt_path='unet.hdf5',
)
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
def dataset(args, mode='train',
            image_color_mode="grayscale", label_color_mode="grayscale",
            image_save_prefix="image", label_save_prefix="label",
            save_to_dir=None, target_size=(256, 256), seed=1,
            preprocess=None):
    """Prepare a segmentation dataset (optional augmentation + batching).

    Images and labels (masks) are loaded with
    ``ImageDataGenerator.flow_from_directory(..., class_mode=None)``. The
    folders are therefore plain image folders, not one folder per class.
    For each source the directory read is ``<root>/<folder>``::

        args['image_root']/args['image_folder1']   images, source 1
        args['label_root']/args['label_folder1']   labels, source 1
        args['image_root']/args['image_folder2']   images, source 2 (train only)
        args['label_root']/args['label_folder2']   labels, source 2 (train only)

    An image batch and its label batch stay aligned only because the two
    iterators of a source share the same ``seed``, ``shuffle`` flag and
    augmentation parameters. Each image folder must therefore hold the same
    number of files as its label folder, in the same order.

    NOTE(review): 'valid'/'test' read the same folders as training source 1,
    so validation currently runs on training data. For a real hold-out set,
    call this with a copy of ``args`` whose roots/folders point at separate
    validation directories.

    Args:
        args (dict): Configuration. Keys read: 'image_root', 'label_root',
            'image_folder1', 'label_folder1', 'batch_size'. In 'train' mode
            it also reads 'image_folder2' and 'label_folder2'.
        mode (str): One of:
            'train' -- random augmentation and shuffling, with batches drawn
            at random from both sources (about 1/3 from source 1, 2/3 from
            source 2).
            'valid' or 'test' -- source 1 only, no augmentation, no
            shuffling.
        image_color_mode (str): ``color_mode`` for the image iterators.
        label_color_mode (str): ``color_mode`` for the label iterators.
        image_save_prefix (str): Filename prefix for saved augmented images.
        label_save_prefix (str): Filename prefix for saved augmented labels.
        save_to_dir (str or None): Directory to save augmented samples to.
        target_size (tuple): Size every image is resized to.
        seed (int): Seed shared by each image/label iterator pair. Source 2
            uses ``seed + 1``. The same value also seeds the per-batch source
            choice in 'train' mode.
        preprocess (callable or None): Optional ``f(img, label) -> (img,
            label)`` applied to every batch before it is yielded, e.g. to
            rescale images to [0, 1] and binarize masks. If None, batches are
            yielded exactly as the iterators produce them. No ``rescale`` is
            configured, so pixel values are not normalized.

    Yields:
        tuple: ``(img, label)`` numpy batches, e.g. ``(4, 256, 256, 1)`` each
        for grayscale input with batch_size 4 and the default target size.

    Raises:
        ValueError: If ``mode`` is not 'train', 'valid' or 'test'. Because
            this is a generator, the error surfaces on the first ``next()``.
    """
    if mode == 'train':
        shuffle = True
        # Images and labels get identical augmentation settings. Combined with
        # the shared seed, this applies the same random transform to an image
        # and its mask (the Keras "transform images and masks together" recipe).
        augmentation = dict(rotation_range=20,
                            horizontal_flip=True,
                            vertical_flip=True,
                            width_shift_range=0.1,
                            height_shift_range=0.1,
                            shear_range=0.2,
                            zoom_range=0.1)
        # brightness_range=[0.8, 1.2] is disabled. If re-enabled, it belongs
        # on the image generator only, never on the label generator.
        image_datagen = ImageDataGenerator(**augmentation)
        label_datagen = ImageDataGenerator(**augmentation)
    elif mode in ('test', 'valid'):
        shuffle = False
        image_datagen = ImageDataGenerator()
        label_datagen = ImageDataGenerator()
    else:
        raise ValueError('dataset mode ERROR!')

    def _flow(datagen, root_key, folder_key, color_mode, save_prefix,
              flow_seed):
        """Build a batch iterator over the single folder <root>/<folder>."""
        return datagen.flow_from_directory(
            args[root_key],
            classes=[args[folder_key]],
            class_mode=None,
            color_mode=color_mode,
            target_size=target_size,
            batch_size=args['batch_size'],
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            shuffle=shuffle,
            seed=flow_seed)

    def _pairs(suffix, pair_seed):
        """Zip the image and label iterators of one source ('1' or '2')."""
        images = _flow(image_datagen, 'image_root', 'image_folder' + suffix,
                       image_color_mode, image_save_prefix, pair_seed)
        labels = _flow(label_datagen, 'label_root', 'label_folder' + suffix,
                       label_color_mode, label_save_prefix, pair_seed)
        return zip(images, labels)

    def _finish(img, label):
        """Apply the optional per-batch ``preprocess`` hook."""
        if preprocess is None:
            return img, label
        return preprocess(img, label)

    source1 = _pairs('1', seed)
    if mode != 'train':
        for img, label in source1:
            yield _finish(img, label)
        return

    source2 = _pairs('2', seed + 1)
    # A private RNG keeps the source choice independent of the global NumPy
    # random state, which the keras augmentation code also draws from.
    rng = np.random.RandomState(seed)
    while True:
        # Re-draw the source for every batch. Keras directory iterators loop
        # over their files indefinitely, so a `for ... in source:` loop would
        # never return here. All batches would then come from whichever
        # source was drawn first. randint(3) == 2 has probability 1/3.
        source = source1 if rng.randint(3) == 2 else source2
        img, label = next(source)
        yield _finish(img, label)
trainset = dataset(args, mode='train')
validset = dataset(args, mode='valid')
# next(validset)[0].shape --> (4,256,256,1)
# next(validset)[1].shape --> (4,256,256,1)
model = unet()
# NOTE(review): the reported InvalidArgumentError ("slice index 1 of
# dimension 3 out of bounds", raised inside the loss) means the loss reads
# channel 1 of a 4-D tensor whose last axis has only one entry. The labels
# yielded above have a single channel (see the shapes above). Presumably
# unet()'s loss and/or output layer assume >= 2 channels, e.g. a one-hot
# background/foreground encoding. TODO: confirm against unet().
# For binary segmentation, either:
#   - use a 1-channel sigmoid output with a loss that reads only channel 0
#     (e.g. binary_crossentropy), or
#   - one-hot encode the labels to 2 channels before they reach fit_generator.
# `shuffle` is not passed: fit_generator only uses it for keras.utils.Sequence
# inputs, so it had no effect on these plain generators.
# validation_steps counts batches: 2000 steps * batch_size 4 = 8000 samples
# per epoch. It is usually ceil(n_validation_images / batch_size).
model.fit_generator(trainset,
                    steps_per_epoch=500,
                    epochs=args['epoch'],
                    validation_data=validset,
                    validation_steps=2000)
调用 fit_generator 时出现了这个错误。如果不使用 fit_generator,而是手动把 X 和 y 堆叠成数组再训练,也会遇到同样的错误。
任务是二值(前景/背景)分割,X 和 y 均为灰度图像。错误的回溯信息如下:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-25-b1490d73b6ae> in <module>
----> 1 model.fit_generator(trainset, steps_per_epoch=500, shuffle=True, epochs=args['epoch'],validation_data=validset, validation_steps=2000)
~/anaconda3/lib/python3.7/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
~/anaconda3/lib/python3.7/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1416 use_multiprocessing=use_multiprocessing,
1417 shuffle=shuffle,
-> 1418 initial_epoch=initial_epoch)
1419
1420 @interfaces.legacy_generator_methods_support
~/anaconda3/lib/python3.7/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
215 outs = model.train_on_batch(x, y,
216 sample_weight=sample_weight,
--> 217 class_weight=class_weight)
218
219 outs = to_list(outs)
~/anaconda3/lib/python3.7/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
1215 ins = x + y + sample_weights
1216 self._make_train_function()
-> 1217 outputs = self.train_function(ins)
1218 return unpack_singleton(outputs)
1219
~/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2713 return self._legacy_call(inputs)
2714
-> 2715 return self._call(inputs)
2716 else:
2717 if py_any(is_tensor(x) for x in inputs):
~/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
2674 else:
-> 2675 fetched = self._callable_fn(*array_vals)
2676 return fetched[:len(self.outputs)]
2677
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1456 ret = tf_session.TF_SessionRunCallable(self._session._session,
1457 self._handle, args,
-> 1458 run_metadata_ptr)
1459 if run_metadata:
1460 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: slice index 1 of dimension 3 out of bounds.
[[{{node loss/activation_1_loss/strided_slice_2}}]]
另外,由于这不是分类问题,我还想知道在使用 ImageDataGenerator 时,训练目录和验证目录应该如何组织。