多输入卷积神经网络的图像分类

时间:2018-11-11 19:18:04

标签: tensorflow machine-learning keras deep-learning computer-vision

我是深度学习的初学者,我想在Keras中为图像分类创建多输入卷积神经网络(CNN)模型。

我正在创建一个CNN模型,该模型可以获取两张图片,并提供一张输出,这是两个图片的

我有两个数据集:type1和type2,每个数据集都包含相同的类,但是数据集type1中每个类的图像数大于中的图像数。数据集类型2中的每个类。 该模型应从Type1数据集中获取一幅图像,从Type2数据集中获取一幅图像,然后将这些图像分类为一类(ClassA或ClassB或------)。

以下是数据集的结构。

Type1 dataset
|Train
              |ClassA
                             |image1
                             |image2
                             |image3
                             |image4
                            -----
              |ClassB
                             |image1
                             |image2
                             |image3
                             |image4
                            -----
              |ClassC
                             |image1
                             |image2
                             |image3
                             |image4
                            -----
              |ClassD
                             |image1
                             |image2
                             |image3
                             |image4
                            -----
       ----------------
|Validate
            -----------
|Test
           --------------

Type2 dataset
|Train
              |ClassA
                             |image1
                             |image2
                            -----
              |ClassB
                             |image1
                             |image2
                            -----
              |ClassC
                             |image1
                             |image2
                            -----
              |ClassD
                             |image1
                             |image2
                            -----
       ----------------
|Validate
            -----------
|Test
           --------------

该模型与该图像中的模型非常相似,但是在展平层之前它具有更多的层。

Multi-input

我创建了一个自定义生成器,该生成器输入两个图像(来自类型1和2),并且来自类型1的每个图像都与来自类型2的每个图像配对,只要这些图像属于相同类(标签)

问题是执行fit_generator时出现如下所示的无限循环:

  Found *** images belonging to 100 classes.

    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes.
    Found *** images belonging to 100 classes. ......
.................................................................

这是我的自定义生成器代码:

input_imgen = ImageDataGenerator( 
                                  rotation_range=10,
                                  shear_range=0.2,
                                  zoom_range=0.1,
                                  width_shift_range=0.1,
                                  height_shift_range=0.1
                                  )



test_imgen = ImageDataGenerator()



def generate_generator_multiple(generator,dir1, dir2, batch_size, img_height,img_width):


    genX1 = generator.flow_from_directory(dir1,
                                          target_size = (img_height,img_width),
                                          class_mode = 'categorical',
                                          batch_size = batch_size,
                                          shuffle=False, 
                                          seed=7)

    genX2 = generator.flow_from_directory(dir2,
                                          target_size = (img_height,img_width),
                                          class_mode = 'categorical',
                                          batch_size = batch_size,
                                          shuffle=False, 
                                          seed=7)
    while True:
      X2i = genX2.next() 
      Type1 = []
      Type2 = []
      image1 = []
      image2 = []

      while True:
        X1i = genX1.next() 
        for i in range(len(X2i[1])): #Type2
          for j in range(len(X1i[1])): #Type1
            if all(X2i[1][i]) == all(X1i[1][j]): # have same label
              image1.append(X1i[0][j]) # add image
              image1.append(X1i[1][j]) # add label
              image2.append(X2i[0][i]) # add image
              image2.append(X2i[1][i]) # add label
      Type1.append(image1)
      Type2.append(image2)
      yield [Type1 [0], Type2 [0]], Type2 [1]  #Yield both images and their mutual label


inputgenerator=generate_generator_multiple(generator=input_imgen,
                                           dir1=train_iris_data,
                                           dir2=train_face_data,
                                           batch_size=32,
                                           img_height=224,
                                           img_width=224)       

validgenerator=generate_generator_multiple(generator=test_imgen,
                                          dir1=valid_iris_data,
                                          dir2=valid_face_data,
                                          batch_size=32,
                                          img_height=224,
                                          img_width=224) 

testgenerator=generate_generator_multiple(generator=test_imgen,
                                          dir1=test_face_data,
                                          dir2=test_face_data,
                                          batch_size=32,
                                          img_height=224,
                                          img_width=224)


    # compile the model
    multi_model.compile(
            loss='categorical_crossentropy',
            optimizer=Adam(lr=0.0001),
            metrics=['accuracy']
        )


# train the model and save the history
history = multi_model.fit_generator(
inputgenerator,
steps_per_epoch=len(train_data) // batch_size,
epochs=10,
verbose=1,
validation_data=validgenerator,
validation_steps=len(valid_data) // batch_size,
use_multiprocessing=True,
shuffle=False
)

能帮我解决这个问题并创建自定义生成器吗?

0 个答案:

没有答案