如何配置Keras模型来预测图像?

时间:2018-11-20 22:06:06

标签: python machine-learning keras deep-learning image-segmentation

主要任务是预测输入图像的蒙版。因此,我需要以下数据进行训练:

  • 很多768x768原始图片,如下所示:

enter image description here

  • 并像这样输出掩码图片(也为768x768):

enter image description here

我还有验证的原始照片。

我准备了某种可以预测输出掩码的神经模型。我准备的keras model configuaration应该具有如下所示的拓扑:

enter image description here

我准备进行培训的代码在那里。

import keras
epochs=100

image_datagen = keras.preprocessing.image.ImageDataGenerator()
mask_datagen = keras.preprocessing.image.ImageDataGenerator()
seed = 1
image_generator = image_datagen.flow_from_directory(
    'H:/LABS/ship_detection/test_train/',
    color_mode='rgb',batch_size=32,target_size=(768,768),
    seed=seed)

mask_generator = mask_datagen.flow_from_directory(
    'H:/LABS/ship_detection/test_mask/',
    class_mode="categorical",batch_size=32,target_size=(768,768),
    seed=seed)

train_generator = zip(image_generator, mask_generator)

model.fit_generator(generator=train_generator,
                    epochs=epochs,
                    callbacks=callbacks,steps_per_epoch=1)

但是当我尝试使用生成器进行预测时,我遇到了一个问题:

c:\users\harwister\appdata\local\programs\python\python36\lib\site-packages\keras\engine\training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    208                     batch_size = list(x.values())[0].shape[0]
    209                 else:
--> 210                     batch_size = x.shape[0]
    211                 batch_logs['batch'] = batch_index
    212                 batch_logs['size'] = batch_size

AttributeError: 'tuple' object has no attribute 'shape'

我肯定做错了事,但是从这些错误中我什么也听不懂。我在Google中找不到响应的一个简单问题是:如何将两幅图像(输入和输出图像)推入Keras进行训练,训练后如何获得提供输入图像的输出图像?

1 个答案:

答案 0 :(得分:0)

由于图像和标签(即蒙版)具有单独的生成器,因此需要将List<String> as = new ArrayList<String>(); HttpTransportProperties.Authenticator basicAuth = new HttpTransportProperties.Authenticator(); as.add(Authenticator.BASIC); basicAuth.setAuthSchemes(as); basicAuth.setUsername("ABC"); basicAuth.setPassword("password"); basicAuth.setPreemptiveAuthentication(true); serviceStub._getServiceClient().getOptions().setProperty( org.apache.axis2.transport.http.HTTPConstants.AUTHENTICATE, basicAuthenticator); 参数设置为class_mode,以防止生成器生成任何标签数组:

None

通过这种方式,image_generator = image_datagen.flow_from_directory(class_mode=None, ...) mask_generator = mask_datagen.flow_from_directory(class_mode=None, ...) 仅生成输入图像,而image_generator仅生成蒙版(即真实标签)图像。