我正在研究有关使用python进行机器学习的学校项目。我创建了一个带张量流的线性分类器,并且已经学习了MNIST数据集,准确度超过90%。
预测数据集测试数据工作正常,但问题是当我想要导入不是来自测试数据集的数据时(可能只是在绘画中创建的图像)。
我为我的演示文稿创建了一个简单的GUI,它也可以正常使用它,但只是没有例如.png图像。
我尝试过Pillow,但看起来效果不好。
你能帮帮我吗?我会接受任何建议。非常感谢。这是tensorflow代码:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)
global i, test_labels
i = 0
def display(i):
img = test_data[i]
plt.title('Example %d, label %d' % (i, test_labels[i]))
plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
plt.show()
global mnist
mnist = learn.datasets.load_dataset("mnist")
test_data = mnist.test.images
test_labels = np.array(mnist.test.labels, dtype=np.int32)
def train_me(max_examples, batch, step):
data = mnist.train.images
labels = np.array(mnist.train.labels, dtype=np.int32)
data = data[:max_examples]
labels = labels[:max_examples]
feature_columns = learn.infer_real_valued_columns_from_input(data)
cls = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)
cls.fit(data, labels, batch_size=batch, steps=step)
return cls
def test_me(cls):
im = Image.open("dva-test.png")
global prediction
prediction = cls.predict(im, as_iterable=False)
这是GUI代码:
import sys
import digits as dig
from PyQt5.QtWidgets import (QApplication, QWidget, QToolTip, QPushButton, QMessageBox, QDesktopWidget, QMainWindow,
QLabel, QAction, QFileDialog)
from PyQt5.QtGui import QIcon
class Gui(QMainWindow):
def __init__(self):
super().__init__()
self.init_ui()
def init_ui(self):
self.setFixedSize(500, 200)
self.center()
self.statusBar().showMessage('Not trained')
exAct = QAction('Exit', self)
exAct.setShortcut('Ctrl+Q')
exAct.triggered.connect(self.close)
impAct = QAction('Import picture', self)
impAct.setShortcut('Ctrl+I')
impAct.triggered.connect(self.file_import)
menubar = self.menuBar()
fileMenu = menubar.addMenu('&File')
fileMenu.addAction(impAct)
fileMenu.addAction(exAct)
trainBtn = QPushButton('Train', self)
trainBtn.resize(trainBtn.sizeHint())
trainBtn.move(155, 120)
trainBtn.clicked.connect(self.trainning)
testBtn = QPushButton('Test', self)
testBtn.resize(trainBtn.sizeHint())
testBtn.move(255, 120)
testBtn.clicked.connect(self.testing)
text = QLabel("Please import file and train the classifier before testing.", self)
text.resize(text.sizeHint())
text.move(120, 40)
self.setWindowIcon(QIcon('icon.png'))
self.setWindowTitle('Digits')
self.show()
def trainning(self):
global classifier
classifier = dig.train_me(10000, 100, 1000)
classifier.evaluate(dig.test_data, dig.test_labels)
self.statusBar().showMessage('Accuracy: ' +
str(classifier.evaluate(dig.test_data, dig.test_labels)['accuracy']))
def testing(self):
dig.i = 2
dig.test_me(classifier)
self.statusBar().showMessage("Predicted %d, label: %d" % (dig.prediction, dig.test_labels[dig.i]))
def file_import(self):
name = QFileDialog.getOpenFileName(self, 'Import File')
print(name)
def closeEvent(self, event):
reply = QMessageBox.question(self, 'Message', "Are you sure you want to exit ?",
QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
if reply == QMessageBox.Yes:
event.accept()
else:
event.ignore()
def center(self):
qr = self.frameGeometry()
cp = QDesktopWidget().availableGeometry().center()
qr.moveCenter(cp)
self.move(qr.topLeft())
if __name__ == '__main__':
app = QApplication(sys.argv)
ui = Gui()
sys.exit(app.exec_())
答案 0 :(得分:1)
解决:
Tensorflow只接受一维数组,我的图像是3D数组。形状= [28,28,3]。所以我删除了RGB维度并对2D数组进行了调整。
我的结果导入了Tensorflow分类器,但我意识到我需要反转颜色,所以数组中的每个零都应该等于1,每1到0。
以下是代码:
im = mpimg.imread('dva-test.png')
im = im[:, :, 0]
im = im.ravel()
for j in range(len(im)):
if im[j] == 0:
im[j] = 1
elif im[j] == 1:
im[j] = 0
global prediction
prediction = cls.predict(np.array([im], dtype=float), as_iterable=False)