我已经建立了一个自定义模型,现在想在Google colab上使用我的Google驱动器中的数据集(数据集属于kaggle中的Google Landmark Recognition Challenge)。 使用im ImageDataGenerator(flow_from_directory)和FitGenerator。
运行fit_generator时,它在第一个纪元的最后一步(验证之前)失败。
我最小化了数据集和模型,以尽可能缩短运行时间,并检查数据的形状是否正确,但看起来是否正确
它实际上是我曾经建立了第一个模型和Im在ML世界还挺新的,所以我觉得我真是怀念这里的东西..
下面是我的代码(在没有一些不需要的部分)
rotation_range=0,
width_shift_range=0,
height_shift_range=0,
shear_range=0,
zoom_range=[0.8, 1.25],
horizontal_flip=True,
vertical_flip=False,
fill_mode='reflect',
data_format='channels_last',
brightness_range=[0.5, 1.5])
test_datagen = ImageDataGenerator(
rotation_range=30,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.01,
zoom_range=[0.8, 1.25],
horizontal_flip=True,
vertical_flip=False,
fill_mode='reflect',
data_format='channels_last',
brightness_range=[0.5, 1.5])
image_size = (128,128)
batch_size = 20
train_generator=train_datagen.flow_from_directory(TRAINING_DATA_DIR,
target_size=image_size,
color_mode='rgb',
batch_size=batch_size,
class_mode='categorical',
shuffle=True)
validation_generator = test_datagen.flow_from_directory(
VALIDATION_DATA_DIR,
target_size=image_size,
color_mode='rgb',
class_mode = "categorical")
model = Sequential()
model.add(Conv2D(64, (3,3), input_shape = (image_size[0], image_size[1], 3)))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(32, (3,3)))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(150))
model.add(Activation("relu"))
model.add(Dense(100))
model.add(Activation('softmax'))
model.summary()
model.compile(loss = "categorical_crossentropy", optimizer = optimizers.SGD(lr=0.0001, momentum=0.9), metrics=["accuracy"])
// Calculate the steps sizes
step_size_train = train_generator.n//train_generator.batch_size
step_size_validation = validation_generator.n//validation_generator.batch_size
// We define an early stopper to avoid wasting compute resources and time
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=2, verbose=1, mode='auto')
networkfileName = RESULTS_PATH + "/" + MODEL_NAME + "{}" + MODEL_ENDING.format(int(time.time()))
checkpoint = ModelCheckpoint(networkfileName, monitor='val_acc', verbose=1,save_best_only=True, save_weights_only=False, mode='auto', period=1)
history = model.fit_generator(generator=train_generator,
steps_per_epoch=step_size_train,
validation_data = validation_generator,
validation_steps = step_size_validation,
epochs=5,
callbacks = [checkpoint, early],
workers=1,
use_multiprocessing=False,
verbose=1)
不幸的是我没有足够多的声誉,所以我不能添加一个截图,所以这里的链接:Exception screenshot
还在此处添加了异常输出:
IndexError Traceback (most recent call last)
<ipython-input-63-db9c44988a08> in <module>()
7 workers=1,
8 use_multiprocessing=False,
----> 9 verbose=1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
2175 use_multiprocessing=use_multiprocessing,
2176 shuffle=shuffle,
-> 2177 initial_epoch=initial_epoch)
2178
2179 def evaluate_generator(self,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
214 break
215
--> 216 callbacks.on_epoch_end(epoch, epoch_logs)
217 epoch += 1
218 if callbacks.model.stop_training:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
212 logs = logs or {}
213 for callback in self.callbacks:
--> 214 callback.on_epoch_end(epoch, logs)
215
216 def on_batch_begin(self, batch, logs=None):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
572 if self.epochs_since_last_save >= self.period:
573 self.epochs_since_last_save = 0
--> 574 filepath = self.filepath.format(epoch=epoch + 1, **logs)
575 if self.save_best_only:
576 current = logs.get(self.monitor)
IndexError: tuple index out of range
谢谢!