How do I train a Siamese network in Keras?

Time: 2019-02-11 18:41:11

Tags: machine-learning keras deep-learning computer-vision conv-neural-network

I have a pandas DataFrame containing the file names of the positive and negative example pairs:

img1        img2      y
001.jpg     002.jpg   1 
003.jpg     004.jpg   0 
003.jpg     002.jpg   1  

I want to use Keras's ImageDataGenerator and flow_from_dataframe to train my Siamese network. How do I set up training so that the code feeds in two images with one label at a time?

Here is the code for my model:

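# Imports assumed for standalone Keras 2.x; with TensorFlow, import from tensorflow.keras instead
from keras import backend as K
from keras.layers import (BatchNormalization, Conv2D, Dense, Flatten, Input,
                          Lambda, MaxPooling2D)
from keras.models import Model, Sequential
from keras.preprocessing.image import ImageDataGenerator

def siamese_model(input_shape):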
    left = Input(input_shape)
    right = Input(input_shape)
    model = Sequential()
    model.add(Conv2D(32, (3,3), activation='relu', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3,3), activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3,3), activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3,3), activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3,3), activation='relu'))
    model.add(MaxPooling2D())
    model.add(BatchNormalization())
    model.add(Flatten())
    model.add(Dense(512, activation='sigmoid'))

    left_encoded = model(left)
    right_encoded = model(right)
    # Element-wise absolute difference between the two encodings
    L1_layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))
    L1_distance = L1_layer([left_encoded, right_encoded])
    prediction = Dense(1,activation='sigmoid')(L1_distance)
    siamese_net = Model(inputs=[left,right],outputs=prediction)
    return siamese_net

model = siamese_model((224,224,3))
model.compile(loss="binary_crossentropy",optimizer="adam", metrics=['accuracy'])

datagen_left = ImageDataGenerator(rotation_range=10,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    vertical_flip = True)
datagen_right = ImageDataGenerator(rotation_range=10,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    vertical_flip = True)

1 answer:

Answer 0 (score: 4)

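Wrap the two generators in a custom generator.

Have one of them supply the label you need and discard the other one's label. This only pairs the images correctly if both generators return the same rows in the same order, so give them the same batch size and make sure their ordering matches (for example, by disabling shuffling in both).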

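from keras.utils import Sequence

class DoubleGenerator(Sequence):
    """Pairs batch i of gen1 with batch i of gen2 and keeps gen1's labels."""

    def __init__(self, gen1, gen2):
        self.gen1 = gen1
        self.gen2 = gen2

    def __len__(self):
        # Batches per epoch; both generators are assumed to have the same length
        return len(self.gen1)

    def __getitem__(self, i):
        x1, y = self.gen1[i]
        x2, _ = self.gen2[i]  # the second generator's labels are discarded
        return (x1, x2), y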

Usage:

double_gen = DoubleGenerator(datagen_left.flow_from_directory(...),
                             datagen_right.flow_from_directory(...))
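
Since the question asks about flow_from_dataframe rather than flow_from_directory, the two flows could also be built straight from the DataFrame. The following is only a sketch under assumptions. `make_pair_flows` and its `image_dir` argument are hypothetical names, not part of Keras or of the answer above. The column names `img1`, `img2` and `y` come from the question. `class_mode="raw"` assumes a Keras version whose `flow_from_dataframe` supports that mode, and `shuffle=False` is chosen so that batch i of both flows covers the same rows.

```python
def make_pair_flows(df, image_dir, datagen_left, datagen_right,
                    target_size=(224, 224), batch_size=32):
    """Build two flows over the same DataFrame rows, one per image column.

    Assumes df has the question's columns 'img1', 'img2' and 'y', and that
    datagen_left / datagen_right are ImageDataGenerator instances.
    """
    shared = dict(
        dataframe=df,
        directory=image_dir,    # hypothetical folder that holds the .jpg files
        y_col="y",
        target_size=target_size,
        batch_size=batch_size,
        class_mode="raw",       # assumption: pass the 0/1 labels through unchanged
        shuffle=False,          # keep row order so batch i matches in both flows
    )
    # Same rows and labels for both flows; only the image column differs.
    flow_left = datagen_left.flow_from_dataframe(x_col="img1", **shared)
    flow_right = datagen_right.flow_from_dataframe(x_col="img2", **shared)
    return flow_left, flow_right
```

The two returned flows can be passed to `DoubleGenerator` in place of the two `flow_from_directory(...)` calls. Because shuffling is switched off here, it may help to shuffle the DataFrame rows once beforehand, for example with `df.sample(frac=1)`.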