Why is this Keras Siamese network for image matching not learning anything?

Asked: 2017-05-30 12:51:18

Tags: neural-network computer-vision deep-learning keras conv-neural-network

As a first sanity check, I am trying to build a network that learns to output 1 for pairs of identical images and 0 for pairs of non-identical images, in the hope that it will quickly overfit.

The loss decreases, but no matter what I try, the accuracy just bounces around 0.5.

I use ResNet50 as the shared pair of Siamese branches, merge them with element-wise subtraction, and feed the resulting "difference layer" into a single sigmoid unit, as described in Siamese Neural Networks for One-shot Image Recognition. I have also tried several of the other commonly suggested variations, e.g. a softmax output, concatenation instead of subtraction, and so on.
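
For context, a minimal sketch of the matching head that paper describes (this is not my full model; the 2048 matches ResNet50's pooled embedding size, and note the paper feeds the element-wise absolute difference into the sigmoid, whereas my full code below uses a signed subtraction):

# Minimal sketch of the one-shot paper's matching head: the sigmoid unit
# sees the element-wise ABSOLUTE difference |h1 - h2| of the two embeddings.
from keras.layers import Input, Lambda, Dense
from keras.models import Model
import keras.backend as K

h1 = Input(shape=(2048,))  # embedding of the query image
h2 = Input(shape=(2048,))  # embedding of the reference image
abs_diff = Lambda(lambda pair: K.abs(pair[0] - pair[1]))([h1, h2])
match_probability = Dense(1, activation='sigmoid')(abs_diff)
head = Model(inputs=[h1, h2], outputs=match_probability)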

The following example can be run by anyone, provided you pass as a command-line argument the path to a directory containing at least 2 images to match.

from keras.applications.resnet50 import ResNet50
from keras.models import Model
# from keras.utils.visualize_util import plot
from keras.layers import Dense, \
    Dropout, \
    Input, \
    GlobalAveragePooling2D, \
    Lambda, \
    BatchNormalization, \
    Activation
from keras.layers.merge import Add, Multiply, Concatenate
from keras.optimizers import Adam, SGD, RMSprop
from keras.engine import Layer
import keras.backend as K
from keras import regularizers
import os
import random
from PIL import Image
import numpy as np
import cv2


def manhattan_distance(pair):
    # L1 (Manhattan) distance between the two embeddings, one scalar per pair
    return K.sum(K.abs(pair[0] - pair[1]), axis=1, keepdims=True)

def _build_base_dense(input_shape):
    input_tensor = Input(shape=input_shape)

    base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)

    for layer in base_model.layers:
        layer.trainable = False

    # final ResNet50 activation map, at 1/32 of the input resolution
    one_thirty_second_resolution = base_model.get_layer('activation_49').output

    pooled = GlobalAveragePooling2D()(one_thirty_second_resolution)
    # dense_1 = Dense(1024, activation='relu', kernel_regularizer=regularizers.l2(0.01))(pooled)
    # dense_1 = BatchNormalization()(dense_1)

    # embedding_model = Model(inputs=input_tensor, outputs=dense_1)
    embedding_model = Model(inputs=input_tensor, outputs=pooled)

    return embedding_model

def build_siamese_dense(input_shape):
    input_query = Input(shape=input_shape)
    input_reference = Input(shape=input_shape)

    base_network = _build_base_dense(input_shape=input_shape)

    embed_query = base_network(input_query)
    embed_reference = base_network(input_reference)

    # dist = Lambda(manhattan_distance)([embed_query, embed_reference])
    # negate one leg and Add() to get an element-wise subtraction of the two Siamese legs
    negative_embed_reference = Lambda(lambda x: x * -1)(embed_reference)
    elementwise_dist = Add()([embed_query, negative_embed_reference])
    # merged = Concatenate()([embed_query, embed_reference])

    # classify = Dense(2, activation='softmax')(dist)
    classify = Dense(1, activation='sigmoid', use_bias=False)(elementwise_dist)
    # classify = Dense(1, activation='sigmoid', use_bias=False)(merged)

    model = Model(inputs=[input_query, input_reference], outputs=classify)

    model.compile(
        optimizer=Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0),
        # optimizer=SGD(lr=0.001, momentum=0.5),
        # loss='categorical_crossentropy', metrics=['accuracy'])
        loss='binary_crossentropy', metrics=['accuracy'])

    return model

def preprocess_cv2_batch(images, dim_ordering='default'):
    # Zero-center each channel by the ImageNet mean pixel. Images come from
    # cv2.imread, which already returns BGR, so no RGB->BGR swap is needed.
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    assert dim_ordering in {'tf', 'th'}

    if dim_ordering == 'th':
        # need to transpose axes to make (batch, channels, height, width)
        print('Image batch arrived with shape: {}'.format(str(images.shape)))
        images = np.transpose(images, (0, 3, 1, 2))
        print('Image batch axes were transposed to shape: {} for THEANO dim-ordering convention'.format(
            str(images.shape)))
        # zero-center by mean pixel
        images[:, 0, :, :] -= 103.939
        images[:, 1, :, :] -= 116.779
        images[:, 2, :, :] -= 123.68
    else:
        # zero-center by mean pixel
        images[:, :, :, 0] -= 103.939
        images[:, :, :, 1] -= 116.779
        images[:, :, :, 2] -= 123.68
    return images

class DataGenerator(object):
    '''
    Class for iterating through a directory of images, creating training pairs on the fly
    '''
    def __init__(self, image_dir, input_shape, prob_positive=0.5):
        self.input_shape = input_shape
        self.image_dir = image_dir
        self.prob_positive = prob_positive
        self.image_file_list = [os.path.join(self.image_dir, item) for item in os.listdir(self.image_dir)]
        assert len(self.image_file_list) >= 2, 'You need at least 2 images in the dir to do matching.'

    def generate_batch(self, batch_size, debug=False):
        while True:
            batch_query_inputs = []
            batch_reference_inputs = []
            batch_labels = []
            num_successful = 0
            while num_successful < batch_size:
                try:
                    # randomly choose a reference image
                    input_pair = np.zeros((2, self.input_shape[0], self.input_shape[1], self.input_shape[2]), dtype=np.float32)
                    # sample a reference image without replacement
                    allowed_indices = list(range(len(self.image_file_list)))
                    random_image_index = random.choice(allowed_indices)
                    allowed_indices.remove(random_image_index)
                    random_image_path = self.image_file_list[random_image_index]
                    random_image_reference = cv2.imread(random_image_path)
                    random_image_reference = cv2.resize(random_image_reference,(self.input_shape[1], self.input_shape[0]))
                    input_pair[1] = random_image_reference

                    # flip a coin to decide whether the training example is a match or not
                    if random.random() < self.prob_positive: # match
                        input_pair[0] = np.array(random_image_reference)
                        is_match = 1
                    else: # no match - choose a different image
                        random_image_index = random.choice(allowed_indices)
                        random_image_path = self.image_file_list[random_image_index]
                        random_image_query = cv2.imread(random_image_path)
                        random_image_query = cv2.resize(random_image_query,(self.input_shape[1], self.input_shape[0]))
                        input_pair[0] = random_image_query
                        is_match = 0

                    input_pair = preprocess_cv2_batch(input_pair)
                    batch_query_inputs.append(input_pair[0])
                    batch_reference_inputs.append(input_pair[1])
                    batch_labels.append(is_match)

                    # DEBUG
                    # cv2.namedWindow('query match={}'.format(is_match))
                    # cv2.imshow('query match={}'.format(is_match), input_pair[0])
                    # cv2.namedWindow('reference match={}'.format(is_match))
                    # cv2.imshow('reference match={}'.format(is_match), input_pair[1])
                    # cv2.waitKey()
                    # cv2.destroyAllWindows()

                    num_successful += 1
                except cv2.error as cv2e:
                    print(cv2e)
                # except Exception as e:
                #     print('There was some kind of exception...')
                #     print(e)
            batch_query_inputs = np.array(batch_query_inputs)
            batch_reference_inputs = np.array(batch_reference_inputs)
            batch_labels = np.array(batch_labels)

            yield [batch_query_inputs, batch_reference_inputs], batch_labels


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('imagedir', help='path to directory where training images are found')
    args = parser.parse_args()

    IMAGE_DIR = args.imagedir
    INPUT_SHAPE = (320, 320, 3)
    BATCH_SIZE = 32
    NUM_ITERATIONS = 500
    VAL_INTERVAL = 50
    MODEL_NAME = 'siamese_experiment'

    data_train = DataGenerator(IMAGE_DIR, INPUT_SHAPE)
    data_val = DataGenerator(IMAGE_DIR, INPUT_SHAPE)
    gen_train = data_train.generate_batch(batch_size=BATCH_SIZE)
    gen_val = data_val.generate_batch(batch_size=BATCH_SIZE)

    net = build_siamese_dense(input_shape=INPUT_SHAPE)
    net.summary()

    # truncate any history files left over from previous runs
    with open('{}.losshistory'.format(MODEL_NAME), 'w'):
        pass
    with open('{}.acchistory'.format(MODEL_NAME), 'w'):
        pass
    for iteration in range(NUM_ITERATIONS):
        # do validation
        if iteration % VAL_INTERVAL == 0:
            print('============\nIteration: {}'.format(iteration))
            batch_X, batch_y = next(gen_val)
            metrics_val = net.evaluate(batch_X, batch_y, batch_size=BATCH_SIZE, verbose=1)
            print('VALIDATION: Loss={}, Acc={}'.format(metrics_val[0], metrics_val[1]))

        batch_X, batch_y = next(gen_train)
        metrics_train = net.train_on_batch(batch_X, batch_y)

        print('============\nIteration: {}'.format(iteration))
        print('TRAIN: Loss={}, Acc={}'.format(metrics_train[0], metrics_train[1]))
        print('============')
        with open('{}.losshistory'.format(MODEL_NAME), 'a') as f:
            f.write('{}\n'.format(metrics_train[0]))
        with open('{}.acchistory'.format(MODEL_NAME), 'a') as f:
            f.write('{}\n'.format(metrics_train[1]))

1 Answer:

Answer 0 (score: 0):

When I trained a Siamese network based on ResNet50, the accuracy was also stuck at 50%, just like yours. But once I normalized the images (i.e. img / 255.), the network started learning. So I suggest you try normalizing your data; the behaviour may be caused by ResNet50's initial weights, and normalized data generally helps a model improve.
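
For reference, a minimal sketch of the normalization suggested above, assuming it replaces the mean-pixel preprocessing in the question's generator (the helper name normalize_batch is hypothetical):

import numpy as np

def normalize_batch(images):
    # scale raw cv2 pixel values from [0, 255] down to [0, 1]
    return images.astype(np.float32) / 255.

# In DataGenerator.generate_batch, swap the preprocessing call:
#     input_pair = preprocess_cv2_batch(input_pair)   # mean-pixel centering
#     input_pair = normalize_batch(input_pair)        # simple /255. scaling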