暹罗网络上的损失不会减少

时间:2019-06-11 20:03:49

标签: python keras neural-network deep-learning

我对机器学习非常陌生,我开始实施一个暹罗网络来检查手写数字的相似度,并使用MNIST数据集进行训练,但是我遇到了严重的丢失问题。

网络模型

import keras
from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Lambda
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K
import cv2
from keras.datasets import mnist
import numpy as np
import random

def siameseNet(input_shape):
    input1 = Input(input_shape)
    input2 = Input(input_shape)

    model = Sequential()
    model.add(Conv2D(50, (5,5), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D())
    model.add(Conv2D(100, (3,3), activation='relu'))
    model.add(MaxPooling2D())
    model.add(Conv2D(100, (3,3), activation='relu'))

    model.add(Flatten())
    model.add(Dense(2048, activation='sigmoid'))

    input_model_1 = model(input1)
    input_model_2 = model(input2)
    distance_func = Lambda(lambda t: K.abs(t[0]-t[1]))
    distance_layer = distance_func([input_model_1, input_model_2])

    prediction = Dense(1,activation='sigmoid')(distance_layer)

    network = Model(inputs=[input1,input2],outputs=prediction)

    return network

训练数据

我的pairs对象是一个numpy数组,其中有两个数组,包含相同索引上的图像,该数组的上半部分是同一类别的图像,下半部分是不同类别的图像。

category对象是一个简单数组,其中包含来自训练集中的相同数量的样本,其上半部分使用0来指定相同图像的Y值,下半部分设置到1

在以下函数中同时填充了pairscategory

INPUT_SHAPE = (28,28,1)

def loadData():
    (X_train, Y_train), _ = mnist.load_data()
    n_samples = 20000
    arrPairs = [np.zeros((n_samples, INPUT_SHAPE[0], INPUT_SHAPE[1],INPUT_SHAPE[2])) for i in range(2)]
    category = np.zeros((n_samples))
    category[n_samples//2:] = 1
    for i in range(n_samples): 
        if i%1000==0:
            print(i)

        cur_category = Y_train[i]

        img = random.choice(X_train[Y_train==cur_category]).reshape(28,28,1)
        _, img = cv2.threshold(img, .8, 1, cv2.THRESH_BINARY)
        arrPairs[0][i] = img.reshape(28,28,1)

        if category[i] == 1:
            img = random.choice(X_train[Y_train!=cur_category])
        else:
            img = random.choice(X_train[Y_train==cur_category])
        _, img = cv2.threshold(img, .8, 1, cv2.THRESH_BINARY)
        arrPairs[1][i] = img.reshape(28,28,1)
    arrPairs[0] = arrPairs[0]/255
    return arrPairs, category

培训结果

pairs, category = loadData()
model = siameseNet(INPUT_SHAPE)
model.compile(optimizer=Adam(lr=0.0005),loss="binary_crossentropy")
model.fit(pairs, category,  epochs=5, verbose=1, validation_split=0.2)

Train on 16000 samples, validate on 4000 samples
Epoch 1/5
16000/16000 [==============================] - 6s 353us/step - loss: 0.6660 - val_loss: 0.9474
Epoch 2/5
16000/16000 [==============================] - 5s 287us/step - loss: 0.6628 - val_loss: 0.9335
Epoch 3/5
16000/16000 [==============================] - 5s 287us/step - loss: 0.6627 - val_loss: 0.8487
Epoch 4/5
16000/16000 [==============================] - 5s 287us/step - loss: 0.6625 - val_loss: 0.9954
Epoch 5/5
16000/16000 [==============================] - 5s 288us/step - loss: 0.6616 - val_loss: 0.9133

但是无论我怎么尝试,损失都不会减少,因此预测不正确。

我尝试更改激活,增加和减少网络复杂性(添加和删除层,以及增加和减少Conv2D参数),但是这些都不起作用,所以我猜测它是一个我所缺少的建筑问题

更新: 用于测试的行:

test_pairs = [np.zeros((2, INPUT_SHAPE[0], INPUT_SHAPE[1],INPUT_SHAPE[2])) for i in range(2)]
test_pairs[0][0] = cv2.cvtColor(cv2.imread('test1_samenumber.png'), cv2.COLOR_BGR2GRAY).reshape(28,28,1); 
test_pairs[1][0] = cv2.cvtColor(cv2.imread('test2_samenumber.png'), cv2.COLOR_BGR2GRAY).reshape(28,28,1);

pred = model.predict(test_pairs)
print(pred)

输出的内容:

[[0.32230237]
 [0.44603676]]

1 个答案:

答案 0 :(得分:1)

在加载数据时,您有不必要的规范化。具体来说,对于第一对图像,不需要时将除以255。用cv2.threshold设置阈值后,输出值本质上为0或1,因此进一步除以255将使动态范围小于第二对图像,这可能会导致学习如何区分两个图像时出现问题。我已通过注释掉arrPairs[0] = arrPairs[0] / 255语句来删除此规范化。

训练完您的网络后,我遍历了每一对,并检查了输出预测。本质上,如果类别为1,并且网络(您的S型层)生成的预测大于0.5,则将其视为正确的预测。同样,当我看到类别为0并且生成的预测小于0.5时,这也是正确的。

correct = 0
for i in range(len(pairs[0])):
    output = model.predict([pairs[0][i][None], pairs[1][i][None]])[0][0]
    if (category[i] == 1 and output >= 0.5) or (category[i] == 0 and output < 0.5):
        correct += 1

print(correct / len(pairs[0]))

我在这里得到99.26%的准确度,这意味着在20000个样本中,有0.74%被错误分类。我会说这是一个很好的结果。

可重现的Google Colab笔记本可以在这里找到:https://colab.research.google.com/drive/10Q6rjuiytRSump2nulW5UhXY_PJh1eor