改善由mnist数据集训练的神经网络的真实结果

时间:2019-12-30 18:27:39

标签: python machine-learning keras mnist handwriting-recognition

我已经使用mnist数据集使用keras构建了一个神经网络,现在我正尝试在实际手写数字的照片上使用它。当然,我并不期望结果会是完美的,但是我目前得到的结果仍有很大的改进空间。

对于初学者来说,我会用我最清晰的笔迹写的一些个人数字照片进行测试。它们是正方形的,并且具有与mnist数据集中的图像相同的尺寸和颜色。它们被保存在名为 individual_test 的文件夹中,例如: 7(2)_digit.jpg

网络经常非常确定错误的结果,我将举一个例子:

clearly a 7

我为这张照片得到的结果如下:

result:  3 . probabilities:  [1.9963557196245318e-10, 7.241294497362105e-07, 0.02658148668706417, 0.9726449251174927, 2.5416460047722467e-08, 2.6078915027483163e-08, 0.00019745019380934536, 4.8302300825753264e-08, 0.0005754049634560943, 2.8358477788259506e-09]

因此,网络有97%的人确定这是3,而这并不是唯一的情况。在38张照片中,只有16张被正确识别。令我感到震惊的是,尽管网络离正确的结果再远了,但它对结果的把握如此确定。

我该怎么做才能提高性能?我可以更好地准备图像吗?还是应该将自己的图像添加到训练数据中?如果是这样,我该怎么做?

编辑

这是在上面应用 prepare_image 后上面显示的图片的样子:

my picture after treatment

相比:这是mnist数据集提供的图片之一:

one of the mnist digits

它们看起来和我非常相似。我该如何改善呢?

这是我的代码:

# import keras and the MNIST dataset
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from keras.utils import np_utils
# numpy is necessary since keras uses numpy arrays
import numpy as np

# imports for pictures
import PIL

# imports for tests
import random
import os

class mnist_network():
    def __init__(self):
        """ load data, create and train model """
        # load data
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        # flatten 28*28 images to a 784 vector for each image
        num_pixels = X_train.shape[1] * X_train.shape[2]
        X_train = X_train.reshape((X_train.shape[0], num_pixels)).astype('float32')
        X_test = X_test.reshape((X_test.shape[0], num_pixels)).astype('float32')
        # normalize inputs from 0-255 to 0-1
        X_train = X_train / 255
        X_test = X_test / 255
        # one hot encode outputs
        y_train = np_utils.to_categorical(y_train)
        y_test = np_utils.to_categorical(y_test)
        num_classes = y_test.shape[1]


        # create model
        self.model = Sequential()
        self.model.add(Dense(num_pixels, input_dim=num_pixels, kernel_initializer='normal', activation='relu'))
        self.model.add(Dense(num_classes, kernel_initializer='normal', activation='softmax'))
        # Compile model
        self.model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

        # train the model
        self.model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)

        self.train_img = X_train
        self.train_res = y_train
        self.test_img = X_test
        self.test_res = y_test


    def predict_result(self, img, num_pixels = None, show=False):
        """ predicts the number in a picture (vector) """
        assert type(img) == np.ndarray and img.shape == (784,)

        """if show:
            # show the picture!!!! some problem here
            plt.imshow(img, cmap='Greys')
            plt.show()"""

        num_pixels = img.shape[0]
        # the actual number
        res_number = np.argmax(self.model.predict(img.reshape(-1,num_pixels)), axis = 1)
        # the probabilities
        res_probabilities = self.model.predict(img.reshape(-1,num_pixels))

        return (res_number[0], res_probabilities.tolist()[0])    # we only need the first element since they only have one



    def prepare_image(self, img):
        """ prepares the partial images used in partial_img_rec by transforming them
            into numpy arrays that the network will be able to process """
        # convert to greyscale
        img = img.convert("L")
        # rescale image to 28 *28 dimension
        img = img.resize((28,28), PIL.Image.ANTIALIAS)
        # inverse colors since the training images have a black background
        img =  PIL.ImageOps.invert(img)
        # transform to vector
        img = np.asarray(img, "float32")
        img = img / 255.
        img[img < 0.5] = 0.
        # flatten image to 28*28 = 784 vector
        num_pixels = img.shape[0] * img.shape[1]
        img = img.reshape(num_pixels)


        return img


    def partial_img_rec(self, image, upper_left, lower_right, results=[], show = False):
        """ partial is a part of an image """
        left_x, left_y = upper_left
        right_x, right_y = lower_right

        print("current test part: ", upper_left, lower_right)
        print("results: ", results)
        # condition to stop recursion: we've reached the full width of the picture
        width, height = image.size
        if right_x > width:
            return results

        partial = image.crop((left_x, left_y, right_x, right_y))
        if show:
            partial.show()
        partial = self.prepare_image(partial)

        step = height // 10
        # is there a number in this part of the image? 

        res, prop = self.predict_result(partial)
        print("result: ", res, ". probabilities: ", prop)
        # only count this result if the network is at least 50% sure
        if prop[res] >= 0.5:        
            results.append(res)
            # step is 80% of the partial image's size (which is equivalent to the original image's height) 
            step = int(height * 0.8)
            print("found valid result")
        else:
            # if there is no number found we take smaller steps
            step = height // 20 
        print("step: ", step)
        # recursive call with modified positions ( move on step variables )
        return self.partial_img_rec(image, (left_x + step, left_y), (right_x + step, right_y), results = results)

    def individual_digits(self, img):
        """ uses partial_img_rec to predict individual digits in square images """
        assert type(img) == PIL.JpegImagePlugin.JpegImageFile or type(img) == PIL.PngImagePlugin.PngImageFile or type(img) == PIL.Image.Image

        return self.partial_img_rec(img, (0,0), (img.size[0], img.size[1]), results=[])

    def test_individual_digits(self):
        """ test partial_img_rec with some individual digits (shape: square) 
            saved in the folder 'individual_test' following the pattern 'number_digit.jpg' """
        cnt_right, cnt_wrong = 0,0
        folder_content = os.listdir(".\individual_test")

        for imageName in folder_content:
            # image file must be a jpg or png
            assert imageName[-4:] == ".jpg" or imageName[-4:] == ".png"
            correct_res = int(imageName[0])
            image = PIL.Image.open(".\\individual_test\\" + imageName).convert("L")
            # only square images in this test
            if image.size[0]  != image.size[1]:
                print(imageName, " has the wrong proportions: ", image.size,". It has to be a square.")
                continue 
            predicted_res = self.individual_digits(image)

            if predicted_res == []:
                print("No prediction possible for ", imageName)
            else:
                predicted_res = predicted_res[0]

            if predicted_res != correct_res:
                print("error in partial_img-rec! Predicted ", predicted_res, ". The correct result would have been ", correct_res)
                cnt_wrong += 1
            else:
                cnt_right += 1
                print("correctly predicted ",imageName)
        print(cnt_right, " out of ", cnt_right + cnt_wrong," digits were correctly recognised. The success rate is therefore ", (cnt_right / (cnt_right + cnt_wrong)) * 100," %.")

    def multiple_digits(self, img):
        """ takes as input an image without unnecessary whitespace surrounding the digits """

        #assert type(img) == myImage
        width, height = img.size
        # start with the first square part of the image
        res_list = self.partial_img_rec(img, (0,0),(height ,height), results = [])
        res_str = ""
        for elem in res_list:
            res_str += str(elem)
        return res_str

    def test_multiple_digits(self):
        """ tests the function 'multiple_digits' using some images saved in the folder 'multi_test'.
            These images contain multiple handwritten digits without much whitespac surrounding them.
            The correct solutions are saved in the files' names followed by the characte '_'. """

        cnt_right, cnt_wrong = 0,0
        folder_content = os.listdir(".\multi_test")
        for imageName in folder_content:
            # image file must be a jpg or png
            assert imageName[-4:] == ".jpg" or imageName[-4:] == ".png"            
            image = PIL.Image.open(".\\multi_test\\" + imageName).convert("L")

            correct_res = imageName.split("_")[0]
            predicted_res = self.multiple_digits(image)
            if correct_res == predicted_res:
                cnt_right += 1
            else:
                cnt_wrong += 1
                print("Error in multiple_digits! The network predicted ", predicted_res, " but the correct result would have been ", correct_res)

        print("The network predicted correctly ", cnt_right, " out of ", cnt_right + cnt_wrong, " pictures. That's a success rate of ", cnt_right / (cnt_right + cnt_wrong) * 100, "%.")

network = mnist_network()
# 7(2).digit.jpg is the image shown above
network.individual_digits(PIL.Image.open(".\individual_test\\7(2)_digit.jpg"))

3 个答案:

答案 0 :(得分:1)

您在MNIST数据集上的测试分数是多少? 我想到的一件事是您的图像缺少阈值,

阈值处理是将某个像素以下的像素值设置为零的技术,请参阅OpenCV阈值示例,您可能需要使用反阈值并再次检查结果。

做,通知是否有进展。

答案 1 :(得分:0)

您遇到的主要问题是,您正在测试的图像与MNIST图像不同,可能是由于已完成图像的准备工作,在您应用prepare_image之后,您能否显示来自正在测试的图像的图像?在上面。

答案 2 :(得分:0)

更新:

您可以通过以下三种方法在此特定任务中获得更好的性能:

  1. 使用卷积网络,因为它在处理具有空间数据(如图像)的任务时表现更好,并且像这样的生成器更具生成性。
  2. 使用或创建和/或生成更多类型的图片,并训练您的网络,并与您的网络一起学习。
  3. 预处理,以使您的图像与以前训练网络的原始MNIST图像更好地对齐。

我刚刚做了一个实验。我检查了每个代表一个数字的MNIST图像。我拍摄了您的图像,并进行了我之前向您建议的一些预处理,例如:

1。设置了一些阈值,但是只是向下消除了背景噪声,因为原始MNIST数据仅对空白背景具有一些最小阈值:

image[image < 0.1] = 0.

2。令人惊讶的是,图像内部数字的大小已被证明是至关重要的,因此我按比例缩放了28 x 28图像内部的数字。我们在数字周围还有更多填充。

3。。由于来自keras的MNIST数据也反转了,所以我反转了图像。

image = ImageOps.invert(image)

4。。最后,像我们在培训中一样,用来缩放数据:

image = image / 255.

预处理后,我使用MNIST数据集训练了模型,该数据集的参数为epochs=12, batch_size=200,结果为:

enter image description here enter image description here

结果: 1 ,概率: 0.6844741106033325

 result:  **1** . probabilities:  [2.0584749904628552e-07, 0.9875971674919128, 5.821426839247579e-06, 4.979299319529673e-07, 0.012240586802363396, 1.1566483948399764e-07, 2.382085284580171e-08, 0.00013023221981711686, 9.620113416985987e-08, 2.5273093342548236e-05]

enter image description here enter image description here

结果: 6 ,概率: 0.9221984148025513

result:  6 . probabilities:  [9.130864782491699e-05, 1.8290626258021803e-07, 0.00020504613348748535, 2.1564576968557958e-07, 0.0002401985548203811, 0.04510130733251572, 0.9221984148025513, 1.9014490248991933e-07, 0.03216308355331421, 3.323434683011328e-08]

enter image description here enter image description here

结果: 7 ,概率: 0.7105212807655334 注意:

result:  7 . probabilities:  [1.0372193770535887e-08, 7.988557626958936e-06, 0.00031014863634482026, 0.0056108818389475346, 2.434678014751057e-09, 3.2280522077599016e-07, 1.4190952857262573e-09, 0.9940618872642517, 1.612859932720312e-06, 7.102244126144797e-06]

您的电话号码 9 有点棘手:

enter image description here enter image description here

我发现带有MNIST数据集的模型获得了关于 9 的两个主要“特征”。上部和下部。与图像上一样,具有良好圆形的上部不是 9 ,而是针对MNIST数据集训练的模型的 3 。根据MNIST数据集, 9 的下部大部分是拉直曲线。因此,由于MNIST样本的缘故,基本上,理想形状的 9 始终是模型的 3 ,除非您再次用足够数量的形状的样本来训练模型> 9 。为了检查我的想法,我用 9 s做了一个子实验:

我的 9 的上部倾斜(根据MNIST,大多数情况下 9 都可以),但底部略微卷曲( 9 不能根据MNIST):

enter image description here

结果: 9 ,概率: 0.5365301370620728

我的 9 ,其上部倾斜(根据MNIST,大多数情况下 9 都可以),并且笔直的底部(按照 9 就可以了) MNIST):

enter image description here

结果: 9 ,概率: 0.923724353313446

您的 9 具有错误解释的形状属性:

enter image description here

结果: 3 ,概率: 0.8158268928527832

result:  3 . probabilities:  [9.367801249027252e-05, 3.9978775021154433e-05, 0.0001467708352720365, 0.8158268928527832, 0.0005801069783046842, 0.04391581565141678, 6.44062723154093e-08, 7.099170943547506e-06, 0.09051419794559479, 0.048875387758016586]


最后只是证明图像缩放(填充)重要性的证据,我上面提到的很重要:

enter image description here

结果: 3 ,概率: 0.9845736622810364

enter image description here

结果: 9 ,概率: 0.923724353313446

因此我们可以看到,如果模型内部形状过大且填充尺寸较小,则模型会解释并解释为 3

我认为使用CNN可以获得更好的性能,但是采样和预处理方式对于在ML任务中获得最佳性能始终至关重要。

希望对您有帮助。

更新2:

我发现了另一个问题,我也检查并证明是正确的,数字在图像内部的放置也很关键,这对于这种类型的NN是有意义的。一个很好的例子,在MNIST数据集中居中放置的数字 7 9 ,如果我们放置新的数字,则图像底部附近会导致较难或易碎的分类用于在图像中心进行分类。我检查了将 7 9 移至底部的理论,因此在图像顶部保留了更多位置,结果几乎是 100%准确性。 由于这是一个 spatial 类型的问题,我想通过 CNN 我们可以更有效地消除它。但是,如果MNIST被指定为居中,那会更好,或者我们可以通过编程方式避免出现此问题。