Python,Tensorflow导入非数据集图像

时间:2018-03-29 20:00:04

标签: python tensorflow machine-learning pyqt5

我正在研究有关使用python进行机器学习的学校项目。我创建了一个带张量流的线性分类器,并且已经学习了MNIST数据集,准确度超过90%。

预测数据集测试数据工作正常,但问题是当我想要导入不是来自测试数据集的数据时(可能只是在绘画中创建的图像)。

我为我的演示文稿创建了一个简单的GUI,它也可以正常使用它,但只是没有例如.png图像。

我尝试过Pillow,但看起来效果不好。

你能帮帮我吗?我会接受任何建议。非常感谢。

这是tensorflow代码:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)

global i, test_labels
i = 0


def display(i):
   img = test_data[i]
   plt.title('Example %d, label %d' % (i, test_labels[i]))
   plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
   plt.show()


global mnist
mnist = learn.datasets.load_dataset("mnist")
test_data = mnist.test.images
test_labels = np.array(mnist.test.labels, dtype=np.int32)


def train_me(max_examples, batch, step):

   data = mnist.train.images
   labels = np.array(mnist.train.labels, dtype=np.int32)

   data = data[:max_examples]
   labels = labels[:max_examples]

   feature_columns = learn.infer_real_valued_columns_from_input(data)
   cls = learn.LinearClassifier(feature_columns=feature_columns,    n_classes=10)
   cls.fit(data, labels, batch_size=batch, steps=step)

   return cls


def test_me(cls):

   im = Image.open("dva-test.png")

   global prediction
   prediction = cls.predict(im, as_iterable=False)

这是GUI代码:

import sys
import digits as dig

from PyQt5.QtWidgets import (QApplication, QWidget, QToolTip,  QPushButton, QMessageBox, QDesktopWidget, QMainWindow,
                         QLabel, QAction, QFileDialog)
from PyQt5.QtGui import QIcon


class Gui(QMainWindow):
    def __init__(self):
        super().__init__()

        self.init_ui()

    def init_ui(self):

        self.setFixedSize(500, 200)
        self.center()
        self.statusBar().showMessage('Not trained')

        exAct = QAction('Exit', self)
        exAct.setShortcut('Ctrl+Q')
        exAct.triggered.connect(self.close)

        impAct = QAction('Import picture', self)
        impAct.setShortcut('Ctrl+I')
        impAct.triggered.connect(self.file_import)

        menubar = self.menuBar()
        fileMenu = menubar.addMenu('&File')
        fileMenu.addAction(impAct)
        fileMenu.addAction(exAct)

        trainBtn = QPushButton('Train', self)
        trainBtn.resize(trainBtn.sizeHint())
        trainBtn.move(155, 120)
        trainBtn.clicked.connect(self.trainning)

        testBtn = QPushButton('Test', self)
        testBtn.resize(trainBtn.sizeHint())
        testBtn.move(255, 120)
        testBtn.clicked.connect(self.testing)

        text = QLabel("Please import file and train the classifier before testing.", self)
        text.resize(text.sizeHint())
        text.move(120, 40)

        self.setWindowIcon(QIcon('icon.png'))
        self.setWindowTitle('Digits')

        self.show()

    def trainning(self):
        global classifier
        classifier = dig.train_me(10000, 100, 1000)

        classifier.evaluate(dig.test_data, dig.test_labels)
    self.statusBar().showMessage('Accuracy: ' +
                                 str(classifier.evaluate(dig.test_data,        dig.test_labels)['accuracy']))

    def testing(self):
        dig.i = 2
        dig.test_me(classifier)
        self.statusBar().showMessage("Predicted %d, label: %d" % (dig.prediction, dig.test_labels[dig.i]))

    def file_import(self):
            name = QFileDialog.getOpenFileName(self, 'Import File')
            print(name)

    def closeEvent(self, event):

        reply = QMessageBox.question(self, 'Message', "Are you sure you want to exit ?",
                                 QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

    def center(self):

        qr = self.frameGeometry()
        cp = QDesktopWidget().availableGeometry().center()
        qr.moveCenter(cp)
        self.move(qr.topLeft())


if __name__ == '__main__':

    app = QApplication(sys.argv)
    ui = Gui()
    sys.exit(app.exec_())

1 个答案:

答案 0 :(得分:1)

解决:

Tensorflow只接受一维数组,我的图像是3D数组。形状= [28,28,3]。所以我删除了RGB维度并对2D数组进行了调整。

我的结果导入了Tensorflow分类器,但我意识到我需要反转颜色,所以数组中的每个零都应该等于1,每1到0。

以下是代码:

    im = mpimg.imread('dva-test.png')
    im = im[:, :, 0]
    im = im.ravel()
    for j in range(len(im)):
        if im[j] == 0:
            im[j] = 1
        elif im[j] == 1:
            im[j] = 0

     global prediction
     prediction = cls.predict(np.array([im], dtype=float), as_iterable=False)