For a first sanity check, I'm trying to build a network that learns to output 1 for identical image pairs and 0 for non-identical pairs, in the hope that it will overfit quickly. The loss decreases, but the accuracy just bounces around 0.5 no matter what I try.
I'm using ResNet50 as the shared pair of Siamese branches, merging them with element-wise subtraction and feeding the resulting "difference layer" into a single sigmoid unit, as described in Siamese Neural Networks for One-shot Image Recognition. I've also tried several of the other suggested variations, e.g. a softmax output, concatenation instead of subtraction, and so on.
The example below can be run by anyone, provided you pass the path to a directory containing at least 2 images to match as a command-line argument.
from keras.applications.resnet50 import ResNet50
from keras.models import Model
# from keras.utils.visualize_util import plot
from keras.layers import merge, \
    Dense, \
    Dropout, \
    Input, \
    GlobalAveragePooling2D, \
    Lambda, \
    BatchNormalization, \
    Activation
from keras.layers.merge import Add, Multiply, Concatenate
from keras.optimizers import Adam, SGD, RMSprop
from keras.engine import Layer
import keras.backend as K
from keras import regularizers
import os
import random
from PIL import Image
import numpy as np
import cv2
def manhattan_distance(pair):
    # L1 distance between the two embeddings (the similarity metric used in the one-shot paper)
    return K.sum(K.abs(pair[0] - pair[1]), axis=1, keepdims=True)
def _build_base_dense(input_shape):
    input_tensor = Input(shape=input_shape)
    base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)
    # freeze the pretrained backbone; only the layers added on top will train
    for layer in base_model.layers:
        layer.trainable = False
    # output of the last ResNet block (1/32 of the input resolution)
    one_thirty_tooth_resolution = base_model.get_layer('activation_49').output
    pooled = GlobalAveragePooling2D()(one_thirty_tooth_resolution)
    # dense_1 = Dense(1024, activation='relu', kernel_regularizer=regularizers.l2(0.01))(pooled)
    # dense_1 = BatchNormalization()(dense_1)
    # embedding_model = Model(inputs=input_tensor, outputs=dense_1)
    embedding_model = Model(inputs=input_tensor, outputs=pooled)
    return embedding_model
def build_siamese_dense(input_shape):
    input_query = Input(shape=input_shape)
    input_reference = Input(shape=input_shape)
    base_network = _build_base_dense(input_shape=input_shape)
    embed_query = base_network(input_query)
    embed_reference = base_network(input_reference)
    # dist = Lambda(manhattan_distance)([embed_query, embed_reference])
    # element-wise subtraction of the two siamese legs, via negate-then-add
    # (newer Keras versions also provide keras.layers.Subtract for this)
    negative_embed_reference = Lambda(lambda x: x * -1)(embed_reference)
    elementwise_dist = Add()([embed_query, negative_embed_reference])
    # merged = Concatenate()([embed_query, embed_reference])
    # classify = Dense(2, activation='softmax')(dist)
    classify = Dense(1, activation='sigmoid', use_bias=False)(elementwise_dist)
    # classify = Dense(1, activation='sigmoid', use_bias=False)(merged)
    model = Model(inputs=[input_query, input_reference], outputs=classify)
    model.compile(
        optimizer=Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0),
        # optimizer=SGD(lr=0.001, momentum=0.5),
        # loss='categorical_crossentropy', metrics=['accuracy'])
        loss='binary_crossentropy', metrics=['accuracy'])
    return model
def preprocess_cv2_batch(images, dim_ordering='default'):
    # ImageNet mean-pixel subtraction on BGR images (cv2 already loads images as BGR)
    # images = images.astype(np.float64)
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    assert dim_ordering in {'tf', 'th'}
    if dim_ordering == 'th':
        # need to transpose axes to make (batch, channels, height, width)
        print('Image batch arrived with shape: {}'.format(str(images.shape)))
        images = np.transpose(images, (0, 3, 1, 2))
        print('Image batch axes were transposed to shape: {} for THEANO dim-ordering convention'.format(
            str(images.shape)))
        # # 'RGB'->'BGR'
        # x = x[:, ::-1, :, :]
        # Zero-center by mean pixel
        images[:, 0, :, :] -= 103.939
        images[:, 1, :, :] -= 116.779
        images[:, 2, :, :] -= 123.68
    else:
        # # 'RGB'->'BGR'
        # x = x[:, :, :, ::-1]
        # Zero-center by mean pixel
        images[:, :, :, 0] -= 103.939
        images[:, :, :, 1] -= 116.779
        images[:, :, :, 2] -= 123.68
    return images
class DataGenerator(object):
    '''
    Class for iterating through a directory of images, creating training pairs on the fly
    '''
    def __init__(self, image_dir, input_shape, prob_positive=0.5):
        self.input_shape = input_shape
        self.image_dir = image_dir
        self.prob_positive = prob_positive
        self.image_file_list = [os.path.join(self.image_dir, item) for item in os.listdir(self.image_dir)]
        assert len(self.image_file_list) >= 2, 'You need at least 2 images in the dir to do matching.'
    def generate_batch(self, batch_size, debug=False):
        while True:
            batch_query_inputs = []
            batch_reference_inputs = []
            batch_labels = []
            num_successful = 0
            while num_successful < batch_size:
                try:
                    # randomly choose a reference image
                    input_pair = np.zeros((2, self.input_shape[0], self.input_shape[1], self.input_shape[2]), dtype=np.float32)
                    # sample an image without replacement (list() needed so .pop() works on Python 3)
                    allowed_indices = list(range(len(self.image_file_list)))
                    random_image_index = random.choice(allowed_indices)
                    allowed_indices.pop(random_image_index)
                    random_image_path = self.image_file_list[random_image_index]
                    random_image_reference = cv2.imread(random_image_path)
                    random_image_reference = cv2.resize(random_image_reference, (self.input_shape[1], self.input_shape[0]))
                    input_pair[1] = random_image_reference
                    # flip a coin to decide whether the training example is a match or not
                    if random.random() < self.prob_positive:  # match
                        input_pair[0] = np.array(random_image_reference)
                        is_match = 1
                    else:  # no match - choose a different image
                        random_image_index = random.choice(allowed_indices)
                        random_image_path = self.image_file_list[random_image_index]
                        random_image_query = cv2.imread(random_image_path)
                        random_image_query = cv2.resize(random_image_query, (self.input_shape[1], self.input_shape[0]))
                        input_pair[0] = random_image_query
                        is_match = 0
                    input_pair = preprocess_cv2_batch(input_pair)
                    batch_query_inputs.append(input_pair[0])
                    batch_reference_inputs.append(input_pair[1])
                    batch_labels.append(is_match)
                    # DEBUG
                    # cv2.namedWindow('query match={}'.format(is_match))
                    # cv2.imshow('query match={}'.format(is_match), input_pair[0])
                    # cv2.namedWindow('reference match={}'.format(is_match))
                    # cv2.imshow('reference match={}'.format(is_match), input_pair[1])
                    # cv2.waitKey()
                    # cv2.destroyAllWindows()
                    num_successful += 1
                except cv2.error as cv2e:
                    print(cv2e)
                # except Exception as e:
                #     print('There was some kind of exception...')
                #     print(e)
            batch_query_inputs = np.array(batch_query_inputs)
            batch_reference_inputs = np.array(batch_reference_inputs)
            batch_labels = np.array(batch_labels)
            yield [batch_query_inputs, batch_reference_inputs], batch_labels
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('imagedir', help='path to directory where training images are found')
    args = parser.parse_args()
    IMAGE_DIR = args.imagedir
    INPUT_SHAPE = (320, 320, 3)
    BATCH_SIZE = 32
    NUM_ITERATIONS = 500
    VAL_INTERVAL = 50
    MODEL_NAME = 'siamese_experiment'
    data_train = DataGenerator(IMAGE_DIR, INPUT_SHAPE)
    data_val = DataGenerator(IMAGE_DIR, INPUT_SHAPE)
    gen_train = data_train.generate_batch(batch_size=BATCH_SIZE)
    gen_val = data_val.generate_batch(batch_size=BATCH_SIZE)
    net = build_siamese_dense(input_shape=INPUT_SHAPE)
    net.summary()
    # start each run with empty history files
    with open('{}.losshistory'.format(MODEL_NAME), 'wb') as f:
        f.truncate()
    with open('{}.acchistory'.format(MODEL_NAME), 'wb') as f:
        f.truncate()
    for iteration in range(NUM_ITERATIONS):
        # do validation
        if iteration % VAL_INTERVAL == 0:
            print('============\nIteration: {}'.format(iteration))
            batch_X, batch_y = next(gen_val)  # next() works on both Python 2 and 3
            metrics_val = net.evaluate(batch_X, batch_y, batch_size=BATCH_SIZE, verbose=1)
            print('VALIDATION: Loss={}, Acc={}'.format(metrics_val[0], metrics_val[1]))
        batch_X, batch_y = next(gen_train)
        metrics_train = net.train_on_batch(batch_X, batch_y)
        print('============\nIteration: {}'.format(iteration))
        print('TRAIN: Loss={}, Acc={}'.format(metrics_train[0], metrics_train[1]))
        print('============')
        with open('{}.losshistory'.format(MODEL_NAME), 'a') as f:
            f.write('{}\n'.format(metrics_train[0]))
        with open('{}.acchistory'.format(MODEL_NAME), 'a') as f:
            f.write('{}\n'.format(metrics_train[1]))
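For reference, a typical invocation (assuming the script above is saved as siamese_experiment.py, a filename chosen here only for illustration) would be:

    python siamese_experiment.py /path/to/training/images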
Answer (score: 0)
When I trained a Siamese network based on ResNet50, the accuracy was always 50%, just like yours. But when I normalized the images (i.e. img / 255.), the network started learning. So I suggest you try normalizing your data; the problem may be caused by ResNet50's initial weights, and normalized data always helps improve your model.
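For concreteness, a minimal sketch of that suggestion, written as a drop-in replacement for the question's preprocess_cv2_batch (the 0-1 rescaling is the only change; the dim_ordering argument is kept purely for signature compatibility):

    def preprocess_cv2_batch(images, dim_ordering='default'):
        # rescale raw pixel values from [0, 255] into [0, 1] instead of
        # mean-pixel subtraction; dim_ordering is unused, kept only so
        # the signature matches the original function
        images = images.astype(np.float32)
        images /= 255.0
        return images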