学习Tensorflow已经有一段时间了,就想能不能实现手写体的实时在线识别,于是进行了一番探索。本文源代码可以在这里下载【Python3+PyQt5+Tensorflow】创建画板,实时在线手写体识别】
用到的库:Python3.6.5 + PyQt5 + PIL,编写GUI程序,实现手写体实时在线识别。最终实现的效果如下图所示,在方框里用鼠标手写数字,左下角显示识别结果,准确率可以达到99.2%。
使用PyQt5制作了一个交互式画板,可以用鼠标在上面写字。画板的程序部分参考了【Python3使用PyQt5制作简单的画板/手写板】。
'''
功能:
利用训练好的模型,进行实时手写体识别
作者:yuhansgg
博客: https://blog.csdn.net/u011389706
日期: 2018/08/06
'''
import tensorflow as tf
from PyQt5.QtWidgets import (QWidget, QPushButton, QLabel)
from PyQt5.QtGui import (QPainter, QPen, QFont)
from PyQt5.QtCore import Qt
from PIL import ImageGrab, Image
class MyMnistWindow(QWidget):
def __init__(self):
super(MyMnistWindow, self).__init__()
self.resize(284, 330) # resize设置宽高
self.move(100, 100) # move设置位置
self.setWindowFlags(Qt.FramelessWindowHint) # 窗体无边框
#setMouseTracking设置为False,否则不按下鼠标时也会跟踪鼠标事件
self.setMouseTracking(False)
self.pos_xy = [] #保存鼠标移动过的点
# 添加一系列控件
self.label_draw = QLabel('', self)
self.label_draw.setGeometry(2, 2, 280, 280)
self.label_draw.setStyleSheet("QLabel{border:1px solid black;}")
self.label_draw.setAlignment(Qt.AlignCenter)
self.label_result_name = QLabel('识别结果:', self)
self.label_result_name.setGeometry(2, 290, 60, 35)
self.label_result_name.setAlignment(Qt.AlignCenter)
self.label_result = QLabel(' ', self)
self.label_result.setGeometry(64, 290, 35, 35)
self.label_result.setFont(QFont("Roman times", 8, QFont.Bold))
self.label_result.setStyleSheet("QLabel{border:1px solid black;}")
self.label_result.setAlignment(Qt.AlignCenter)
self.btn_recognize = QPushButton("识别", self)
self.btn_recognize.setGeometry(110, 290, 50, 35)
self.btn_recognize.clicked.connect(self.btn_recognize_on_clicked)
self.btn_clear = QPushButton("清空", self)
self.btn_clear.setGeometry(170, 290, 50, 35)
self.btn_clear.clicked.connect(self.btn_clear_on_clicked)
self.btn_close = QPushButton("关闭", self)
self.btn_close.setGeometry(230, 290, 50, 35)
self.btn_close.clicked.connect(self.btn_close_on_clicked)
def paintEvent(self, event):
painter = QPainter()
painter.begin(self)
pen = QPen(Qt.black, 30, Qt.SolidLine)
painter.setPen(pen)
if len(self.pos_xy) > 1:
point_start = self.pos_xy[0]
for pos_tmp in self.pos_xy:
point_end = pos_tmp
if point_end == (-1, -1):
point_start = (-1, -1)
continue
if point_start == (-1, -1):
point_start = point_end
continue
painter.drawLine(point_start[0], point_start[1], point_end[0], point_end[1])
point_start = point_end
painter.end()
def mouseMoveEvent(self, event):
'''
按住鼠标移动事件:将当前点添加到pos_xy列表中
'''
#中间变量pos_tmp提取当前点
pos_tmp = (event.pos().x(), event.pos().y())
#pos_tmp添加到self.pos_xy中
self.pos_xy.append(pos_tmp)
self.update()
def mouseReleaseEvent(self, event):
'''
重写鼠标按住后松开的事件
在每次松开后向pos_xy列表中添加一个断点(-1, -1)
'''
pos_test = (-1, -1)
self.pos_xy.append(pos_test)
self.update()
def btn_recognize_on_clicked(self):
bbox = (104, 104, 380, 380)
im = ImageGrab.grab(bbox) # 截屏,手写数字部分
im = im.resize((28, 28), Image.ANTIALIAS) # 将截图转换成 28 * 28 像素
recognize_result = self.recognize_img(im) # 调用识别函数
self.label_result.setText(str(recognize_result)) # 显示识别结果
self.update()
def btn_clear_on_clicked(self):
self.pos_xy = []
self.label_result.setText('')
self.update()
def btn_close_on_clicked(self):
self.close()
def recognize_img(self, img): # 手写体识别函数
myimage = img.convert('L') # 转换成灰度图
tv = list(myimage.getdata()) # 获取图片像素值
tva = [(255 - x) * 1.0 / 255.0 for x in tv] # 转换像素范围到[0 1], 0是纯白 1是纯黑
init = tf.global_variables_initializer()
saver = tf.train.Saver
with tf.Session() as sess:
sess.run(init)
saver = tf.train.import_meta_graph('minst_cnn_model.ckpt.meta') # 载入模型结构
saver.restore(sess, 'minst_cnn_model.ckpt') # 载入模型参数
graph = tf.get_default_graph() # 加载计算图
x = graph.get_tensor_by_name("x:0") # 从模型中读取占位符变量
keep_prob = graph.get_tensor_by_name("keep_prob:0")
y_conv = graph.get_tensor_by_name("y_conv:0") # 关键的一句 从模型中读取占位符变量
prediction = tf.argmax(y_conv, 1)
predint = prediction.eval(feed_dict={x: [tva], keep_prob: 1.0}, session=sess) # feed_dict输入数据给placeholder占位符
print(predint[0])
return predint[0]
识别时,先利用函数ImageGrab.grab(bbox),对屏幕画板部分进行截图。然后对截图进行预处理(缩放到28*28像素,转换成灰度图等)。
在最后,最重要的手写体识别函数里recognize_img(self, img),我们调用了已经训练好的模型minst_cnn_model.ckpt,具体模型训练过程,参见【Tensorflow(2):MNIST识别自己手写的数字–进阶篇(CNN)】
实例化我们上面定义的窗体类MyMnistWindow,实现窗体显示。
import sys
from PyQt5.QtWidgets import QApplication
from MyMnistWindow import MyMnistWindow
if __name__ == "__main__":
app = QApplication(sys.argv)
mymnist = MyMnistWindow()
mymnist.show()
app.exec_()
最终测试结果如下所示,左下角显示识别结果。可以看到,基本都能正确识别:
若识别准确率不高,一般是由于我们手写数字和训练数据相差太大导致的。
一般原因是:训练集是西方的手写数字,和中国的手写数字习惯不同。下面是官方的训练数据中的部分数字。
在画图时,笔法尽量和上面训练集保持一致,就会得到较高的识别率!。
本文源代码可以在这里下载【Python3+PyQt5+Tensorflow】创建画板,实时在线手写体识别】
是以为记!