Hyperparameter optimization in images using Talos and flow_from_directory

Time: 2019-04-06 17:31:51

Tags: keras conv-neural-network grid-search hyperparameters talos

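I am trying to optimize the hyperparameters of my Keras CNN for image classification. I considered using grid search from sklearn and the Talos optimizer (https://github.com/autonomio/talos). I got past the basic difficulty of making x and y from flow_from_directory (code below), but... it still doesn't work! Any ideas? Maybe someone has run into the same problem.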
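Talos takes the search space as a `params` dictionary that maps each hyperparameter name to a list of candidate values, and a full grid search trains one model per combination. The question never shows its `params`, so the values below are hypothetical; only the keys mirror the ones read inside the question's `talos_model` (`dropout`, `kernel_initializer`, `loss`, `optimizer`, `epochs`). This sketch enumerates the grid with the standard library to show how quickly the number of training runs grows:

```python
import itertools

# Hypothetical search space: keys match those used in talos_model,
# values are illustrative only.
params = {
    "dropout": [0.0, 0.25, 0.5],
    "kernel_initializer": ["glorot_uniform", "he_normal"],
    "loss": ["categorical_crossentropy"],
    "optimizer": ["adam", "rmsprop"],
    "epochs": [10, 20],
}

# A full grid search trains one model per combination of values.
combinations = [dict(zip(params, values))
                for values in itertools.product(*params.values())]
print(len(combinations))  # 3 * 2 * 1 * 2 * 2 = 24
```

With an image CNN each combination is a complete training run, which is a good reason to start from a small grid.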


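from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
import talos as ta

def talos_model(train_flow, validation_flow, nb_train_samples, nb_validation_samples, params):

    # LeNet-style CNN; dropout, initializer, loss, optimizer and epochs come from params
    model = Sequential()

    # input_shape is (height, width, channels), matching target_size=(img_height, img_width)
    model.add(Conv2D(6, (5, 5), activation="relu", padding="same",
                     input_shape=(img_height, img_width, 3)))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(params['dropout']))

    model.add(Conv2D(16, (5, 5), activation="relu"))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(params['dropout']))

    model.add(Flatten())

    model.add(Dense(120, activation='relu', kernel_initializer=params['kernel_initializer']))
    model.add(Dropout(params['dropout']))
    model.add(Dense(84, activation='relu', kernel_initializer=params['kernel_initializer']))
    model.add(Dropout(params['dropout']))
    model.add(Dense(10, activation='softmax'))  # 10 output classes

    model.compile(loss=params['loss'],
                  optimizer=params['optimizer'],
                  metrics=['categorical_accuracy'])

    # Save the model whenever val_categorical_accuracy improves
    checkpointer = ModelCheckpoint(filepath='talos_cnn.h5',
                                   monitor='val_categorical_accuracy', save_best_only=True)

    # Keras 2 counts batches, not samples: steps_per_epoch / validation_steps
    # replace the Keras 1 arguments samples_per_epoch / nb_val_samples
    history = model.fit_generator(generator=train_flow,
                                  steps_per_epoch=nb_train_samples // batch_size,
                                  validation_data=validation_flow,
                                  validation_steps=nb_validation_samples // batch_size,
                                  callbacks=[checkpointer],
                                  verbose=1,
                                  epochs=params['epochs'])

    # Talos expects the model function to return (history, model)
    return history, model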

# Rescale pixel values from [0, 255] to [0, 1] (1./255 avoids integer division in Python 2)
train_generator = ImageDataGenerator(rescale=1./255)

validation_generator = ImageDataGenerator(rescale=1./255)

# Retrieve images and their classes for train and validation sets
train_flow = train_generator.flow_from_directory(directory=train_data_dir,
                                                 batch_size=batch_size,
                                                 target_size=(img_height, img_width))

# shuffle=False keeps the validation batches in a fixed order
validation_flow = validation_generator.flow_from_directory(directory=validation_data_dir,
                                                           batch_size=batch_size,
                                                           target_size=(img_height, img_width),
                                                           shuffle=False)

# Here I make x and y for Talos
# (next() returns a single batch of batch_size images, not the whole dataset)
(X_train, Y_train) = train_flow.next()

# Starting an experiment with Talos
t = ta.Scan(x=X_train,
            y=Y_train,
            model=talos_model,
            params=params,
            dataset_name='landmarks',
            experiment_no='1')
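Because `train_flow.next()` yields one batch of `batch_size` images, `X_train` and `Y_train` hold only that batch, not the whole training set. If the images fit in memory, one way to build full arrays for `ta.Scan` is to pull one epoch's worth of batches and stack them. A minimal sketch, assuming a freshly created Keras-style iterator whose `next()` returns `(x_batch, y_batch)` tuples and whose last batch may be partial; `collect_all` is a hypothetical helper, not part of Keras or Talos:

```python
import math
import numpy as np

def collect_all(flow, n_samples, batch_size):
    """Stack enough batches from a Keras-style iterator to cover n_samples once."""
    steps = math.ceil(n_samples / batch_size)  # the last batch may be partial
    xs, ys = [], []
    for _ in range(steps):
        x_batch, y_batch = next(flow)
        xs.append(x_batch)
        ys.append(y_batch)
    # Trim in case the iterator returned more than n_samples in total
    return np.concatenate(xs)[:n_samples], np.concatenate(ys)[:n_samples]
```

With the question's objects this would be called as `collect_all(train_flow, train_flow.n, batch_size)`, where `n` is the sample count Keras stores on the iterator. The same `ceil(n / batch_size)` count is the `steps_per_epoch` value that shows every image once per epoch. Loading every image into RAM is exactly what `flow_from_directory` is meant to avoid, so this only suits modestly sized datasets.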

The error occurs on the last line:

The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
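One mismatch worth checking, offered only as a guess since the full traceback is not shown: Talos documents the model function as `(x_train, y_train, x_val, y_val, params)` returning `(history, model)` and passes those arrays in that order, while the function in the question names its parameters `train_flow, validation_flow, nb_train_samples, nb_validation_samples`. The sketch below uses a hypothetical stand-in with the question's parameter names (no Keras or Talos needed) to show that the "sample count" parameters then receive NumPy arrays, and that any truth test on such an array raises the reported `ValueError`:

```python
import numpy as np

# Hypothetical stand-in with the same parameter names as the question's
# talos_model; it only reports what each parameter was bound to.
def model_with_question_signature(train_flow, validation_flow,
                                  nb_train_samples, nb_validation_samples, params):
    return {"train_flow": train_flow,
            "nb_train_samples": nb_train_samples,
            "nb_validation_samples": nb_validation_samples}

def truth_test_error(value):
    """Return the ValueError message raised by bool(value), or None if there is none."""
    try:
        bool(value)
    except ValueError as err:
        return str(err)
    return None

# Dummy arrays in the order of Talos's (x_train, y_train, x_val, y_val, params) convention
x_train = np.zeros((6, 32, 32, 3))
y_train = np.eye(6)
x_val = np.zeros((2, 32, 32, 3))
y_val = np.eye(6)[:2]

bound = model_with_question_signature(x_train, y_train, x_val, y_val, {"epochs": 1})

# The "sample count" is now a 4-D array of images, not an int
print(type(bound["nb_train_samples"]))  # <class 'numpy.ndarray'>
message = truth_test_error(bound["nb_train_samples"])
print(message)
```

If this is the cause, renaming the parameters alone would not be enough: the function body would also need to train on the arrays Talos supplies, for example with `model.fit(x_train, y_train, validation_data=(x_val, y_val), ...)`, instead of on generators.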

0 answers