贝叶斯分类器库源码:
/****************************************************************************************\
* Normal Bayes Classifier *
\****************************************************************************************/
/* The structure, representing the grid range of statmodel parameters.
It is used for optimizing statmodel accuracy by varying model parameters,
the accuracy estimate being computed by cross-validation.
The grid is logarithmic, so <step> must be greater then 1. */
class CvMLData;
struct CV_EXPORTS CvParamGrid
{
// SVM params type
enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
CvParamGrid()
{
min_val = max_val = step = 0;
}
CvParamGrid( double _min_val, double _max_val, double log_step )
{
min_val = _min_val;
max_val = _max_val;
step = log_step;
}
//CvParamGrid( int param_id );
bool check() const;
double min_val;
double max_val;
double step;
};
// Declaration of OpenCV's normal (Gaussian) Bayes classifier interface.
// Definitions live elsewhere in the library; this header only fixes the API.
class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
{
public:
// Creates an empty, untrained classifier.
CvNormalBayesClassifier();
virtual ~CvNormalBayesClassifier();
// Constructs and immediately trains on the given data.
// _var_idx / _sample_idx optionally select a subset of features / samples.
CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
// Trains the model; with update=true the new data is added to an
// already-trained model instead of replacing it.
virtual bool train( const CvMat* _train_data, const CvMat* _responses,
const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
// Predicts class labels for each row of _samples; when <results> is
// non-null the per-sample labels are stored there as well.
virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
// Releases all trained state, returning the model to the untrained condition.
virtual void clear();
#ifndef SWIG
// cv::Mat overloads of the C-struct API above (hidden from SWIG wrappers).
CvNormalBayesClassifier( const cv::Mat& _train_data, const cv::Mat& _responses,
const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat() );
virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
const cv::Mat& _var_idx = cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
bool update=false );
virtual float predict( const cv::Mat& _samples, cv::Mat* results=0 ) const;
#endif
// Model (de)serialization via OpenCV's CvFileStorage mechanism.
virtual void write( CvFileStorage* storage, const char* name ) const;
virtual void read( CvFileStorage* storage, CvFileNode* node );
protected:
int var_count, var_all;
CvMat* var_idx;       // selected feature indices (0 when all features used)
CvMat* cls_labels;    // distinct class labels seen during training
// Per-class training statistics; presumably one matrix per class
// (counts, sums, products, means) — confirm against the .cpp definitions.
CvMat** count;
CvMat** sum;
CvMat** productsum;
CvMat** avg;
CvMat** inv_eigen_values;
CvMat** cov_rotate_mats;
CvMat* c;
};
测试源码:
//源码引用自:http://blog.csdn.net/carson2005/article/details/6854024#
#include "stdafx.h"
#include <ml.h>
#include <iostream>
#include <highgui.h>
#include <cv.h>
#include <cxcore.h>
using namespace cv;
using namespace std;
// Training set: 10 samples, 12-dimensional feature vectors.
// Each row is [class label | 12 feature values] — column 0 is the label (+1/-1).
double inputArr[10][13] =
{
1,0.708333,1,1,-0.320755,-0.105023,-1,1,-0.419847,-1,-0.225806,0,1,
-1,0.583333,-1,0.333333,-0.603774,1,-1,1,0.358779,-1,-0.483871,0,-1,
1,0.166667,1,-0.333333,-0.433962,-0.383562,-1,-1,0.0687023,-1,-0.903226,-1,-1,
-1,0.458333,1,1,-0.358491,-0.374429,-1,-1,-0.480916,1,-0.935484,0,-0.333333,
-1,0.875,-1,-0.333333,-0.509434,-0.347032,-1,1,-0.236641,1,-0.935484,-1,-0.333333,
-1,0.5,1,1,-0.509434,-0.767123,-1,-1,0.0534351,-1,-0.870968,-1,-1,
1,0.125,1,0.333333,-0.320755,-0.406393,1,1,0.0839695,1,-0.806452,0,-0.333333,
1,0.25,1,1,-0.698113,-0.484018,-1,1,0.0839695,1,-0.612903,0,-0.333333,
1,0.291667,1,1,-0.132075,-0.237443,-1,1,0.51145,-1,-0.612903,0,0.333333,
1,0.416667,-1,1,0.0566038,0.283105,-1,1,0.267176,-1,0.290323,0,1
};
// Feature vector of a single test sample (12 features, no label column).
double testArr[]=
{
0.25,1,1,-0.226415,-0.506849,-1,-1,0.374046,-1,-0.83871,0,-1
};
// Demo driver: train a normal Bayes classifier on inputArr, persist it,
// reload it, and classify the single sample in testArr.
int _tmain(int argc, _TCHAR* argv[])
{
const int kSampleCount  = 10;  // rows of inputArr
const int kFeatureCount = 12;  // feature columns (label column excluded)

// Build the training feature matrix. Column 0 of inputArr is the class
// label, so features start at column 1. The explicit cast avoids the
// implicit double->float narrowing warning of the original code.
Mat trainData(kSampleCount, kFeatureCount, CV_32FC1);
for (int i = 0; i < kSampleCount; i++)
{
for (int j = 0; j < kFeatureCount; j++)
{
trainData.at<float>(i, j) = static_cast<float>(inputArr[i][j + 1]);
}
}

// Class labels: first column of each training row.
Mat trainResponse(kSampleCount, 1, CV_32FC1);
for (int i = 0; i < kSampleCount; i++)
{
trainResponse.at<float>(i, 0) = static_cast<float>(inputArr[i][0]);
}

// Train the Bayes classifier and save the model on success.
CvNormalBayesClassifier nbc;
bool trainFlag = nbc.train(trainData, trainResponse);
if (trainFlag)
{
cout<<"train over..."<<endl;
nbc.save("c:/normalBayes.txt");
}
else
{
cout<<"train error..."<<endl;
system("pause");
exit(-1);
}

// Reload the saved model into a fresh classifier to exercise persistence.
CvNormalBayesClassifier testNbc;
testNbc.load("c:/normalBayes.txt");

// Build the test sample and classify it.
Mat testSample(1, kFeatureCount, CV_32FC1);
for (int i = 0; i < kFeatureCount; i++)
{
testSample.at<float>(0, i) = static_cast<float>(testArr[i]);
}
float flag = testNbc.predict(testSample); // predicted class label
cout<<"flag = "<<flag<<endl;
system("pause");
return 0;
}
参考链接 (References):
http://blog.csdn.net/carson2005/article/details/6854024
http://blog.csdn.net/godenlove007/article/details/8913007