我使用 Keras 创建了一个模型,并使用 EMNIST 数据集进行了训练。我有以下代码:
导入模块
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
安装数据集
!pip install extra-keras-datasets
from extra_keras_datasets import emnist
(train_images, train_labels), (test_images, test_labels) = emnist.load_data(type='letters')
处理数据
train_images = train_images / 255.0
test_images = test_images / 255.0
创建模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(27, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5)
然后,我下载了模型。
然后,我有以下代码显示一个窗口来写一封信,并希望预测它属于什么类。
导入模块
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
import tkinter as tk
import pyscreenshot as ImageGrab
import numpy as np
加载模型
model = load_model('path/model1.h5')
函数进行预测
def predict_digit(img):
img = img.resize((28,28))
img = img.convert('L')
img = np.array(img)
img = img.reshape(1,28,28,1)
img = img / 255.0
res = model.predict([img])[0]
return np.argmax(res), max(res)
界面代码
class App(tk.Tk):
def __init__(self):
tk.Tk.__init__(self)
self.x = self.y = 0
self.canvas = tk.Canvas(self, width=300, height=300, bg = "white", cursor="cross")
self.label = tk.Label(self, text="Thinking..", font=("Arial", 48))
self.classify_btn = tk.Button(self, text = "Recognise", command = self.classify_handwriting)
self.button_clear = tk.Button(self, text = "Clear", command = self.clear_all)
self.canvas.grid(row=0, column=0, pady=2, sticky=W, )
self.label.grid(row=0, column=1,pady=2, padx=2)
self.classify_btn.grid(row=1, column=1, pady=2, padx=2)
self.button_clear.grid(row=1, column=0, pady=2)
self.canvas.bind("<B1-Motion>", self.draw_lines)
def getter(self):
widget = self.canvas
x = self.winfo_rootx() + widget.winfo_x()
y = self.winfo_rooty() + widget.winfo_y()
x1 = x + widget.winfo_width()
y1 = y + widget.winfo_height()
im = ImageGrab.grab(bbox=(x,y,x1,y1))
return im
def clear_all(self):
self.canvas.delete("all")
def classify_handwriting(self):
img = self.getter()
digit, acc = predict_digit(img)
self.label.configure(text= str(digit)+', '+ str(int(acc*100))+'%')
def draw_lines(self, event):
self.x = event.x
self.y = event.y
r=8
self.canvas.create_oval(self.x-r, self.y-r, self.x + r, self.y + r, fill='black')
app = App()
mainloop()
然而,每次我尝试写一封要预测的信时,它总是显示班级编号 17。我错过了什么?
此外,我使用以下内容作为该程序的参考:
提前致谢。
答案 0 :(得分:-1)
我看到的问题是,您仅使用 2 个常规 Dense 层作为图像识别模型。这不会为该模型产生足够的可训练参数。由于您正在解决涉及图像识别的问题,因此您应该在密集层和展平层之前使用卷积层。一个二维卷积层的例子:
model.add(Conv2D(64, (3, 3), activation='relu'))
。