Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)

   学习Tensorflow已经有一段时间了,就想能不能实现手写体的实时在线识别,于是进行了一番探索。本文源代码可以在这里下载【Python3+PyQt5+Tensorflow】创建画板,实时在线手写体识别】
   用到的库:Python3.6.5 + PyQt5 + PIL,编写GUI程序,实现手写体实时在线识别。最终实现的效果如下图所示,在方框里用鼠标手写数字,左下角显示识别结果,准确率可以达到99.2%。


Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第1张图片

1.画板GUI及模型加载(MyMnistWindow.py)

  使用PyQt5制作了一个交互式画板,可以用鼠标在上面写字。画板的程序部分参考了【Python3使用PyQt5制作简单的画板/手写板】。

'''
    功能:
        利用训练好的模型,进行实时手写体识别
    作者:yuhansgg
    博客: https://blog.csdn.net/u011389706
    日期: 2018/08/06
'''
import tensorflow as tf
from PyQt5.QtWidgets import (QWidget, QPushButton, QLabel)
from PyQt5.QtGui import (QPainter, QPen, QFont)
from PyQt5.QtCore import Qt
from PIL import ImageGrab, Image

class MyMnistWindow(QWidget):

    def __init__(self):
        super(MyMnistWindow, self).__init__()

        self.resize(284, 330)  # resize设置宽高
        self.move(100, 100)    # move设置位置
        self.setWindowFlags(Qt.FramelessWindowHint)  # 窗体无边框
        #setMouseTracking设置为False,否则不按下鼠标时也会跟踪鼠标事件
        self.setMouseTracking(False)

        self.pos_xy = []  #保存鼠标移动过的点

        # 添加一系列控件
        self.label_draw = QLabel('', self)
        self.label_draw.setGeometry(2, 2, 280, 280)
        self.label_draw.setStyleSheet("QLabel{border:1px solid black;}")
        self.label_draw.setAlignment(Qt.AlignCenter)

        self.label_result_name = QLabel('识别结果:', self)
        self.label_result_name.setGeometry(2, 290, 60, 35)
        self.label_result_name.setAlignment(Qt.AlignCenter)

        self.label_result = QLabel(' ', self)
        self.label_result.setGeometry(64, 290, 35, 35)
        self.label_result.setFont(QFont("Roman times", 8, QFont.Bold))
        self.label_result.setStyleSheet("QLabel{border:1px solid black;}")
        self.label_result.setAlignment(Qt.AlignCenter)

        self.btn_recognize = QPushButton("识别", self)
        self.btn_recognize.setGeometry(110, 290, 50, 35)
        self.btn_recognize.clicked.connect(self.btn_recognize_on_clicked)

        self.btn_clear = QPushButton("清空", self)
        self.btn_clear.setGeometry(170, 290, 50, 35)
        self.btn_clear.clicked.connect(self.btn_clear_on_clicked)

        self.btn_close = QPushButton("关闭", self)
        self.btn_close.setGeometry(230, 290, 50, 35)
        self.btn_close.clicked.connect(self.btn_close_on_clicked)

    def paintEvent(self, event):
        painter = QPainter()
        painter.begin(self)
        pen = QPen(Qt.black, 30, Qt.SolidLine)
        painter.setPen(pen)

        if len(self.pos_xy) > 1:
            point_start = self.pos_xy[0]
            for pos_tmp in self.pos_xy:
                point_end = pos_tmp

                if point_end == (-1, -1):
                    point_start = (-1, -1)
                    continue
                if point_start == (-1, -1):
                    point_start = point_end
                    continue

                painter.drawLine(point_start[0], point_start[1], point_end[0], point_end[1])
                point_start = point_end
        painter.end()

    def mouseMoveEvent(self, event):
        '''
            按住鼠标移动事件:将当前点添加到pos_xy列表中
        '''
        #中间变量pos_tmp提取当前点
        pos_tmp = (event.pos().x(), event.pos().y())
        #pos_tmp添加到self.pos_xy中
        self.pos_xy.append(pos_tmp)

        self.update()

    def mouseReleaseEvent(self, event):
        '''
            重写鼠标按住后松开的事件
            在每次松开后向pos_xy列表中添加一个断点(-1, -1)
        '''
        pos_test = (-1, -1)
        self.pos_xy.append(pos_test)

        self.update()

    def btn_recognize_on_clicked(self):
        bbox = (104, 104, 380, 380)
        im = ImageGrab.grab(bbox)    # 截屏,手写数字部分
        im = im.resize((28, 28), Image.ANTIALIAS)  # 将截图转换成 28 * 28 像素

        recognize_result = self.recognize_img(im)  # 调用识别函数

        self.label_result.setText(str(recognize_result))  # 显示识别结果
        self.update()

    def btn_clear_on_clicked(self):
        self.pos_xy = []
        self.label_result.setText('')
        self.update()

    def btn_close_on_clicked(self):
        self.close()

    def recognize_img(self, img):  # 手写体识别函数
        myimage = img.convert('L')  # 转换成灰度图
        tv = list(myimage.getdata())  # 获取图片像素值
        tva = [(255 - x) * 1.0 / 255.0 for x in tv]  # 转换像素范围到[0 1], 0是纯白 1是纯黑

        init = tf.global_variables_initializer()
        saver = tf.train.Saver  

        with tf.Session() as sess:
            sess.run(init)
            saver = tf.train.import_meta_graph('minst_cnn_model.ckpt.meta')  # 载入模型结构
            saver.restore(sess, 'minst_cnn_model.ckpt')  # 载入模型参数

            graph = tf.get_default_graph()  # 加载计算图
            x = graph.get_tensor_by_name("x:0")  # 从模型中读取占位符变量
            keep_prob = graph.get_tensor_by_name("keep_prob:0")
            y_conv = graph.get_tensor_by_name("y_conv:0")  # 关键的一句  从模型中读取占位符变量

            prediction = tf.argmax(y_conv, 1)
            predint = prediction.eval(feed_dict={x: [tva], keep_prob: 1.0}, session=sess)  # feed_dict输入数据给placeholder占位符
            print(predint[0])
        return predint[0]

   识别时,先利用函数ImageGrab.grab(bbox),对屏幕画板部分进行截图。然后对截图进行预处理(缩放到28*28像素,转换成灰度图等)。
   在最后,最重要的手写体识别函数里recognize_img(self, img),我们调用了已经训练好的模型minst_cnn_model.ckpt,具体模型训练过程,参见【Tensorflow(2):MNIST识别自己手写的数字–进阶篇(CNN)】

2.主程序(main.py)

   实例化我们上面定义的窗体类MyMnistWindow,实现窗体显示。

import sys
from PyQt5.QtWidgets import QApplication
from MyMnistWindow import MyMnistWindow

if __name__ == "__main__":
    app = QApplication(sys.argv)
    mymnist = MyMnistWindow()
    mymnist.show()
    app.exec_()

3.实验结果

   最终测试结果如下所示,左下角显示识别结果。可以看到,基本都能正确识别:

Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第2张图片Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第3张图片Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第4张图片
Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第5张图片Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第6张图片Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第7张图片
Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第8张图片Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第9张图片Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第10张图片

4.注意事项

    若识别准确率不高,一般是由于我们手写数字和训练数据相差太大导致的
    一般原因是:训练集是西方的手写数字,和中国的手写数字习惯不同。下面是官方的训练数据中的部分数字。


Tensorflow(3):创建画板,实时在线手写体识别--终极篇(PyQt5)_第11张图片

    在画图时,笔法尽量和上面训练集保持一致,就会得到较高的识别率!
    本文源代码可以在这里下载【Python3+PyQt5+Tensorflow】创建画板,实时在线手写体识别】
    是以为记!

你可能感兴趣的:(Deep,Learning)