基于Opencv库中SVM模块的MNIST手写字识别数据库识别

基于Opencv库中SVM模块的MNIST手写字识别数据库识别代码。

 

       MNIST 手写数字数据库包含 60000 个训练样本和 10000 个测试样本,是更大的 NIST 数据集的一个子集。其中的数字已经过尺寸归一化(size-normalized)并居中,每张图像的大小固定为 28×28 像素。

       官方地址:http://yann.lecun.com/exdb/mnist/

       该网站对数据集有详细介绍。需要注意的是,数据以二进制文件格式存储,而不是图片格式,所以需要先了解其存放格式,再在 OpenCV 中进行数据格式转换。文件头中的 32 位整数均以高位在前(MSB first,即大端序)的方式存储,在 x86 等小端机器上读取时需要进行字节序转换。

数据格式:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset]  [type]         [value]            [description]

0000     32 bit integer  0x00000801(2049)  magic number (MSB first)

0004     32 bit integer  60000             number of items

0008    unsigned byte   ??                label

........

xxxx    unsigned byte   ??                label

The label values are 0 to 9.

 

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset]   [type]         [value]           [description]

0000     32 bit integer  0x00000803(2051)   magic number (MSB first)

0004     32 bit integer  60000            number of images

0008     32 bit integer  28               number of rows

0012     32 bit integer  28               number of columns

0016    unsigned byte   ??               pixel

0017    unsigned byte   ??               pixel

........

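xxxx    unsigned byte   ??               pixel

Pixels are organized row-wise. Pixel values are 0 to 255; 0 means background, 255 means foreground.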



SVM的识别错误率:

基于Opencv库中SVM模块的MNIST手写字识别数据库识别_第1张图片

界面:

基于Opencv库中SVM模块的MNIST手写字识别数据库识别_第2张图片

环境:OpenCV 2.4 + Ubuntu(Linux)


nistlabledata.h

#ifndef NISTLABLEDATA_H
#define NISTLABLEDATA_H
#include 
#include "nisttraindata.h"
#include "trainsformdata.h"

using namespace std;
using namespace cv;

class NISTLableData:public trainsformdata
{
public:
    NISTLableData();
    ~NISTLableData();
private:
    long int magic_number;
    long int number_of_items;
    static const long int magic_number_setted= 0x801;

    //friend long int NISTTrainData::trainsform_32bitDataform(long int &data,unsigned char* char_nums);

public:
    unsigned char magic_numbers[4],number_items[4];
    long int getnumber_of_items();
    bool check_magic_number();

    unsigned char lable;

    void trainsform_Dataforms()
    {
         trainsform_32bitDataform(magic_number,magic_numbers);
         trainsform_32bitDataform(number_of_items,number_items);
    }
    void show_Data()
    {
        cout<<"magic_number:"<

nistlabledata.cpp

#include "nistlabledata.h"

// Start with zeroed header fields; they are filled in later from the raw
// header bytes (magic_numbers / number_items).
NISTLableData::NISTLableData()
    : magic_number(0),
      number_of_items(0)
{
}

// No explicit cleanup is performed here.
NISTLableData::~NISTLableData() {}

nisttraindata.h

#ifndef NISTTRAINDATA_H
#define NISTTRAINDATA_H

#include 
#include "trainsformdata.h"

using namespace std;
using namespace cv;
/// Header reader and preprocessing helper for an MNIST IDX3 image file.
///
/// The caller reads the raw 4-byte header fields into magicNum / ccount /
/// crows / ccols; trainsform_Dataforms() decodes them (through the inherited
/// trainsformdata::trainsform_32bitDataform) into the private numeric fields,
/// which the getters then expose.
class NISTTrainData:public trainsformdata
{

public:
    NISTTrainData();
    ~NISTTrainData();

private:
    long int magic_number;       // decoded from magicNum
    long int number_of_images;   // decoded from ccount
    long int number_of_rows;     // decoded from crows
    long int number_of_columns;  // decoded from ccols

    // Magic number expected at the start of an IDX3 image file (0x00000803 == 2051).
    static const long int magic_number_setted= 0x803;

public:
    // Side lengths of the normalized samples GetROI() produces.
    static const int image_row= 20;
    static const int image_col= 20;

    // Raw header fields exactly as read from the file (MSB first).
    unsigned char magicNum[4], ccount[4], crows[4], ccols[4];

    /// Crops the digit in src to a square around its non-zero pixels and
    /// resizes it to image_col x image_row, writing the result to dst.
    void GetROI(Mat& src, Mat& dst);

    long int getnumber_of_images();
    long int getrows();
    long int getcols();

    /// Decodes the raw header byte arrays into the numeric fields above.
    void trainsform_Dataforms();

    /// Prints the decoded header fields to stdout.
    void show_Data();

    /// Validates the decoded magic number; presumably compares it with
    /// magic_number_setted — definition not shown here, TODO confirm.
    bool check_magic_number();

    // NOTE(review): not referenced anywhere in the visible code.
    uchar data[64];
};

#endif // NISTTRAINDATA_H

nisttraindata.cpp

#include "nisttraindata.h"
#include "trainsformdata.h"
//#include "trainsformdata.h"

// Zero every decoded header field; trainsform_Dataforms() fills them in from
// the raw header bytes once those have been read from the file.
NISTTrainData::NISTTrainData()
    : magic_number(0),
      number_of_images(0),
      number_of_rows(0),
      number_of_columns(0)
{
}

// No explicit cleanup: every member is a plain value or fixed-size array.
NISTTrainData::~NISTTrainData() {}

/**
 * Crops the digit out of a raw image and rescales it to image_col x image_row.
 *
 * The bounding box of all non-zero pixels is located, grown to a square whose
 * side is max(width, height) and which fully contains the box, copied out of
 * src (parts of the square that fall outside src stay zero), and resized.
 *
 * @param src  input image. The pixel scan reads it as single-channel 8-bit
 *             (uchar) data — assumes CV_8UC1, TODO confirm against the caller.
 * @param dst  receives an image of image_row rows and image_col columns.
 *             A blank src (no non-zero pixel) yields an all-zero image.
 */
void NISTTrainData::GetROI(Mat& src, Mat& dst)
{
    // Bounding box of the non-zero pixels. right/bottom start at -1 so that
    // an image without any ink can be detected afterwards.
    int left = src.cols, right = -1;
    int top = src.rows, bottom = -1;
    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                if (j < left)   left = j;
                if (j > right)  right = j;
                if (i < top)    top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    if (right < left || bottom < top)
    {
        // Blank input: there is nothing to crop, and the size arithmetic
        // below would otherwise go negative.
        dst = Mat::zeros(image_row, image_col, src.type());
        return;
    }

    int width  = right - left + 1;
    int height = bottom - top + 1;
    int len = (width < height) ? height : width;

    // Top-left corner of a len x len square that keeps the bounding box
    // centred. It may extend past the image border, so every source read
    // below is bounds-checked instead of indexing src blindly.
    int x0 = left - (len - width) / 2;
    int y0 = top - (len - height) / 2;

    Mat square = Mat::zeros(len, len, src.type());
    for (int i = 0; i < len; i++)
    {
        int sy = y0 + i;
        if (sy < 0 || sy >= src.rows)
            continue;
        const uchar* srow = src.ptr<uchar>(sy);
        uchar* drow = square.ptr<uchar>(i);
        for (int j = 0; j < len; j++)
        {
            int sx = x0 + j;
            if (sx >= 0 && sx < src.cols)
                drow[j] = srow[sx];
        }
    }

    // cv::Size takes (width, height), i.e. (columns, rows).
    resize(square, dst, Size(image_col, image_row));
}

long int NISTTrainData::getnumber_of_images()
{
    return number_of_images;
}
long int NISTTrainData::getrows()
{
    return number_of_rows;
}
long int NISTTrainData::getcols()
{
    return number_of_columns;
}
// Decodes each raw 4-byte header field into its numeric counterpart, in
// file order: magic number, image count, rows, columns.
void NISTTrainData::trainsform_Dataforms()
{
    long int* const fields[4] = { &magic_number, &number_of_images,
                                  &number_of_rows, &number_of_columns };
    unsigned char* const raw[4] = { magicNum, ccount, crows, ccols };

    for (int k = 0; k < 4; ++k)
        trainsform_32bitDataform(*fields[k], raw[k]);
}
void NISTTrainData::show_Data()
{
    cout<<" magic_number: "<

trainsformdata.h

#ifndef TRAINSFORMDATA_H
#define TRAINSFORMDATA_H


/// Common base of NISTTrainData and NISTLableData: decodes the 32-bit
/// big-endian (MSB first) integers found in MNIST IDX file headers.
class trainsformdata
{
public:
    trainsformdata();
    ~trainsformdata();

    /// Decodes four bytes stored most-significant byte first.
    /// @param char_nums  pointer to at least 4 readable bytes.
    /// @return the decoded value.
    static long int decode_be32(const unsigned char* char_nums)
    {
        unsigned long int value = 0;
        for (int k = 0; k < 4; ++k)
            value = (value << 8) | char_nums[k];
        return static_cast<long int>(value);
    }

    /// Decodes four MSB-first bytes into @p data and returns the same value.
    ///
    /// @p data is overwritten, not accumulated into. The previous version
    /// added to it, so the result depended on @p data's prior contents and
    /// doubled if a header was decoded twice.
    long int trainsform_32bitDataform(long int &data, const unsigned char* char_nums)
    {
        data = decode_be32(char_nums);
        return data;
    }
};
#endif // TRAINSFORMDATA_H

trainsformdata.cpp

#include "trainsformdata.h"

// Nothing to initialize: the class holds no data members.
trainsformdata::trainsformdata() {}

// Nothing to release: the class holds no data members.
trainsformdata::~trainsformdata() {}

mainwindow.cpp

#include "mainwindow.h"
#include "ui_mainwindow.h"



#include 
#include 
#include 
#include 
#include 
#include "qdebug.h"
#include "nisttraindata.h"
#include "nistlabledata.h"

#include 
#include 

using namespace std;
using namespace cv;

#define NTRAINING_SAMPLES   100         // 每类训练样本的数量
#define FRAC_LINEAR_SEP     0.9f        // 线性可分部分的样本组成比例


/// One sample handed to the SVM: its label byte plus the normalized image
/// flattened into image_row * image_col floats.
struct InputData
{
    unsigned char lable;   // sample label (identifier spelling is used elsewhere)
    float data[NISTTrainData::image_row * NISTTrainData::image_col];
};

// Scratch sample that the loaders fill in before each push_back into a buffer.
InputData InputData_;

vector buffer;

void MainWindow::on_pushButton_2_clicked()//载入数据
{
    //Open image and label file


    NISTTrainData TData;
    NISTLableData LData;


    const char fileName[] = "../res/train-images.idx3-ubyte";
    const char labelFileName[] = "../res/train-labels.idx1-ubyte";

    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(fileName, ios_base::binary);

    if( ifs.fail() == true )
    {
        cout<<"train fail"<= total||count==TData.getnumber_of_images())
             break;
        count++;
        ifs.read((char *)(src.data), TData.getcols()*TData.getrows());
        TData.GetROI(src,roi);
        lab_ifs.read((char *)(&(LData.lable)),sizeof(char));

        //imshow("1",roi);

        LData.lable =LData.lable+'0';

        cout<<"lable:"<(i, j);
            }
        }
        buffer.push_back(InputData_);
    }

    cout<<"load trainingdata ok"<& trainData = buffer;
    int testCount = trainData.size();
    int featureLen = NISTTrainData::image_col*NISTTrainData::image_row;
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(testCount, 1, CV_32SC1);

    for (int i= 0; i< testCount; i++)
    {
        InputData td = trainData.at(i);
        memcpy(m.data, td.data, featureLen*sizeof(float));
        normalize(m, m);
        memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));

        res.at(i, 0) = td.lable;
    }

//    Mat showm = Mat::zeros(20, 20, CV_32FC1);
//    for(int i =0;i(i,j) = ((InputData)trainData.at(1)).data[i*showm.cols+j];
//        }
//    imshow("sss",showm);

    CvSVM svm = CvSVM();
    CvSVMParams param;
    CvTermCriteria criteria;

    criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    //param= CvSVMParams(CvSVM::C_SVC, CvSVM::LINEAR, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);


    cout<<"training..."< Testbuffer;
    NISTTrainData TData;
    NISTLableData LData;


    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(fileName, ios_base::binary);

    if( ifs.fail() == true )
    {
        cout<<"train fail"<= total||count==TData.getnumber_of_images())
             break;
        count++;
        ifs.read((char *)(src.data), TData.getcols()*TData.getrows());
        TData.GetROI(src,roi);
        lab_ifs.read((char *)(&(LData.lable)),sizeof(char));

        //imshow("1",roi);

        LData.lable =LData.lable+'0';

        cout<<"lable:"<(i, j);
            }
        }
        Testbuffer.push_back(InputData_);
    }

    vector& trainData = Testbuffer;
    int testCount = trainData.size();
    int featureLen = NISTTrainData::image_col*NISTTrainData::image_row;
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    cout<<"load trainingdata ok"<

其中一个样本(博主已将其尺寸归一化):

导入数据的输出提示:

基于Opencv库中SVM模块的MNIST手写字识别数据库识别_第3张图片

模型训练提示输出:

基于Opencv库中SVM模块的MNIST手写字识别数据库识别_第4张图片

测试集测试结果:使用线性核时正确率为 92.83%,低于上述官方网站上列出的正确率,可能是参数没有调好。

基于Opencv库中SVM模块的MNIST手写字识别数据库识别_第5张图片



你可能感兴趣的:(linux,opencv,图像处理,linux,c++,svm,opencv,机器学习)