预处理使用keras函数ImageDataGenerator()生成的图像来训练resnet50模型

时间:2018-05-02 11:17:45

标签: python keras generator resnet image-preprocessing

我正在尝试训练resnet50模型用于图像分类问题。我已经加载了' imagenet'在我拥有的图像数据集上训练模型之前预训练的权重。我正在使用keras函数flow_from_directory()从目录加载图像。

train_datagen = ImageDataGenerator()
train_generator = train_datagen.flow_from_directory(
        './train_qcut_2_classes',
        batch_size=batch_size,
        shuffle=True,
        target_size=input_size[1:],
        class_mode='categorical')  
test_datagen = ImageDataGenerator()
validation_generator = test_datagen.flow_from_directory(
        './validate_qcut_2_classes',
        batch_size=batch_size,
        target_size=input_size[1:],
        shuffle=True,
        class_mode='categorical')

我将生成器作为fit_generator函数中的参数传递。

hist2=model.fit_generator(train_generator,
                        samples_per_epoch=102204,
                        validation_data=validation_generator,
                        nb_val_samples=25547,
                        nb_epoch=80, callbacks=callbacks,
                        verbose=1)

问题:

使用此设置,如何在将输入图像传递给模型之前使用preprocess_input()函数对其进行预处理?

from keras.applications.resnet50 import preprocess_input

我尝试使用preprocessing_function参数,如下所示

train_datagen=ImageDataGenerator(preprocessing_function=preprocess_input)
train_generator = train_datagen.flow_from_directory(
        './train_qcut_2_classes',
        batch_size=batch_size,
        shuffle=True,
        target_size=input_size[1:],
        class_mode='categorical')  
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
validation_generator = test_datagen.flow_from_directory(
        './validate_qcut_2_classes',
        batch_size=batch_size,
        target_size=input_size[1:],
        shuffle=True,
        class_mode='categorical')

当我尝试提取预处理输出时,我得到了以下结果。

train_generator.next()[0][0]

array([[[  91.06099701,   80.06099701,   96.06099701, ...,   86.06099701,
       52.06099701,   12.06099701],
    [ 101.06099701,  104.06099701,  118.06099701, ...,  101.06099701,
       63.06099701,   19.06099701],
    [ 117.06099701,  103.06099701,   88.06099701, ...,   88.06099701,
       74.06099701,   18.06099701],
    ..., 
    [-103.93900299, -103.93900299, -103.93900299, ...,  -24.93900299,
      -38.93900299,  -24.93900299],
    [-103.93900299, -103.93900299, -103.93900299, ...,  -52.93900299,
      -27.93900299,  -39.93900299],
    [-103.93900299, -103.93900299, -103.93900299, ...,  -45.93900299,
      -29.93900299,  -28.93900299]],

   [[  81.22100067,   70.22100067,   86.22100067, ...,   69.22100067,
       37.22100067,   -0.77899933],
    [  91.22100067,   94.22100067,  108.22100067, ...,   86.22100067,
       50.22100067,    6.22100067],
    [ 107.22100067,   93.22100067,   78.22100067, ...,   73.22100067,
       62.22100067,    6.22100067],
    ..., 
    [-116.77899933, -116.77899933, -116.77899933, ...,  -36.77899933,
      -50.77899933,  -36.77899933],
    [-116.77899933, -116.77899933, -116.77899933, ...,  -64.77899933,
      -39.77899933,  -51.77899933],
    [-116.77899933, -116.77899933, -116.77899933, ...,  -57.77899933,
      -41.77899933,  -40.77899933]],

   [[  78.31999969,   67.31999969,   83.31999969, ...,   61.31999969,
       29.31999969,   -7.68000031],
    [  88.31999969,   91.31999969,  105.31999969, ...,   79.31999969,
       43.31999969,   -0.68000031],
    [ 104.31999969,   90.31999969,   75.31999969, ...,   66.31999969,
       53.31999969,   -2.68000031],
    ..., 
    [-123.68000031, -123.68000031, -123.68000031, ...,  -39.68000031,
      -53.68000031,  -39.68000031],
    [-123.68000031, -123.68000031, -123.68000031, ...,  -67.68000031,
      -42.68000031,  -54.68000031],
    [-123.68000031, -123.68000031, -123.68000031, ...,  -60.68000031,
      -44.68000031,  -43.68000031]]], dtype=float32)

为了确认这一点,我直接在特定图像上使用预处理功能

import cv2
img = cv2.imread('./images.jpg')
img = img_to_array(img)
x = np.expand_dims(img, axis=0)
x = x.astype(np.float64)
x = preprocess_input(x)

,它给出了以下输出,

array([[[[ 118.061,  125.061,  134.061, ...,   97.061,   99.061,  102.061],
     [ 118.061,  125.061,  133.061, ...,   98.061,  100.061,  102.061],
     [ 113.061,  119.061,  126.061, ...,  100.061,  101.061,  102.061],
     ..., 
     [  65.061,   64.061,   64.061, ...,   60.061,   61.061,   57.061],
     [  64.061,   64.061,   63.061, ...,   66.061,   67.061,   59.061],
     [  56.061,   59.061,   62.061, ...,   61.061,   60.061,   59.061]],

    [[ 113.221,  120.221,  129.221, ...,  112.221,  114.221,  113.221],
     [ 116.221,  123.221,  131.221, ...,  113.221,  115.221,  113.221],
     [ 118.221,  124.221,  131.221, ...,  115.221,  116.221,  113.221],
     ..., 
     [  56.221,   55.221,   55.221, ...,   51.221,   52.221,   51.221],
     [  55.221,   55.221,   54.221, ...,   57.221,   58.221,   53.221],
     [  47.221,   50.221,   53.221, ...,   52.221,   51.221,   50.221]],

    [[ 109.32 ,  116.32 ,  125.32 , ...,  106.32 ,  108.32 ,  108.32 ],
     [ 111.32 ,  118.32 ,  126.32 , ...,  107.32 ,  109.32 ,  108.32 ],
     [ 111.32 ,  117.32 ,  124.32 , ...,  109.32 ,  110.32 ,  108.32 ],
     ..., 
     [  34.32 ,   33.32 ,   33.32 , ...,   30.32 ,   31.32 ,   26.32 ],
     [  33.32 ,   33.32 ,   32.32 , ...,   36.32 ,   37.32 ,   28.32 ],
     [  25.32 ,   28.32 ,   31.32 , ...,   30.32 ,   29.32 ,   28.32 ]]]])

关于为什么会发生这种情况的任何想法?

1 个答案:

答案 0 :(得分:2)

创建ImageDataGenerator时的参数:

train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)