opencv svm 多分类问题

SVM 本质上是二分类器,但实际应用中常常需要多分类。OpenCV 的 C_SVC 类型直接支持 n 类(n ≥ 2)分类,下面的代码演示了如何用它训练并使用一个三分类模型。

话不多说,直接上代码。

// svm_test.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;
using namespace std;

// Export macro for the DLL build (MSVC-specific __declspec).
#define skyface_API	extern __declspec(dllexport)
// Exported classifier entry point: predicts a class label for one feature
// vector using the SVM model file at `modpth`.
skyface_API int sex_detect(vector<float> &feats, const char* modpth);

// Reads `num` samples of `dim` whitespace-separated float values each from the
// text file at `path` and returns them as a num x dim CV_32FC1 matrix, one
// sample per row.
//
// path: text file of feature values, sample after sample.
// num:  number of samples (rows) to read.
// dim:  values per sample; defaults to 2048, the feature width used
//       throughout this program.
// Throws std::runtime_error if the file cannot be opened or contains fewer
// than num * dim parseable values (previously such data was silently padded
// with zeros and then trained on).
Mat traindata(string path, int num, int dim = 2048)
{
	Mat samples(num, dim, CV_32FC1);
	ifstream ifs(path.c_str());
	if (!ifs)
		throw runtime_error("cannot open feature file: " + path);
	for (int i = 0; i < num; i++)
	{
		// Parse straight into the matrix row; no intermediate copy needed.
		float* row = samples.ptr<float>(i);
		for (int j = 0; j < dim; j++)
		{
			if (!(ifs >> row[j]))
				throw runtime_error("not enough feature values in: " + path);
		}
	}
	return samples;
}

// Stacks the samples of three classes vertically into one training matrix:
// rows of class1 first, then class2, then class3 (the same order that
// get_labels3 assigns labels 1, 2, 3).
//
// vconcat asserts that all inputs share the same column count and type. The
// previous copyTo() into a fixed 2048-column buffer would instead silently
// reallocate the row-range header on a mismatch, leaving the stacked matrix
// partially uninitialized.
Mat get_traindata3(Mat class1, Mat class2, Mat class3)
{
	Mat parts[] = { class1, class2, class3 };
	Mat traindata;
	vconcat(parts, 3, traindata);
	cout << "获取到训练数据!" << endl;
	return traindata;
}

// Builds the label column matching get_traindata3's row order: a single-column
// CV_32FC1 matrix with 1 for every class1 row, 2 for every class2 row and 3 for
// every class3 row.
Mat get_labels3(Mat class1, Mat class2, Mat class3)
{
	// bounds[k-1]..bounds[k] is the row span belonging to class k.
	const int bounds[4] = {
		0,
		class1.rows,
		class1.rows + class2.rows,
		class1.rows + class2.rows + class3.rows
	};
	Mat labels(bounds[3], 1, CV_32FC1);
	for (int k = 1; k <= 3; ++k)
		labels.rowRange(bounds[k - 1], bounds[k]).setTo(k);
	return labels;
}


// Trains a linear C-SVC (C = 0.1) on `traindata` (one sample per row) with the
// per-row class ids in `labels`, then saves the model.
//
// modelpth: destination model file. "*" or an empty string selects the
//           historical default "../data/model_AGE.txt", which is the path
//           main() reloads after training; the parameter was previously
//           ignored altogether.
// On training failure an error is printed and no model file is written.
void trainSVM(Mat traindata, Mat labels, string modelpth)
{
	const string outpath = (modelpth.empty() || modelpth == "*")
		? string("../data/model_AGE.txt")
		: modelpth;

	//------------------------ 2. Set up the support vector machines parameters --------------------
	CvSVMParams params;
	params.svm_type = SVM::C_SVC;
	params.C = 0.1;
	params.kernel_type = SVM::LINEAR;
	// Only CV_TERMCRIT_ITER is set, so the 1e-6 epsilon is not used as a stop criterion.
	params.term_crit = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);

	//------------------------ 3. Train the svm ----------------------------------------------------
	cout << "Starting training process" << endl;
	CvSVM svm;
	if (!svm.train(traindata, labels, Mat(), Mat(), params))
	{
		cerr << "SVM training failed; model not saved" << endl;
		return;
	}
	cout << "Finished training process" << endl;
	svm.save(outpath.c_str());
}

int sex_detect(vector &feats, const char* modpth)
{
	CvSVM SVM;
	SVM.load(modpth);
	int i;
	float* testdata = new float[2048];
	for (int i = 0; i < 2048; i++)
	{
		testdata[i] = feats[i];
	}
	Mat test = Mat(1, 2048, CV_32FC1, testdata);
	float result = SVM.predict(test);
	delete[] testdata;
	return result;
}
int main()
{

  //int labels[3]=[class1,class2,class3];
  
	Mat class1 = traindata("../data/feats_left.txt",40);
	Mat class2 = traindata("../data/feats_right.txt",36);
	Mat class3 = traindata("../data/feats_pos.txt",48);
	
	//Mat traindata = get_traindata(class1, class2);
	//Mat labels = get_labels(class1, class2);
	
	Mat traindata = get_traindata3(class1, class2, class3);
	Mat labels = get_labels3(class1, class2, class3);
	
	trainSVM(traindata, labels, "*");
	CvSVM SVM;
	SVM.load("../data/model_AGE.txt");
	ifstream ifs;
	float testdata[2048];
	ifs.open("../data/feats_test.txt");
		for (int i = 0; i < 2048; i++)
		{
			ifs >> testdata[i];
		}
		Mat test = Mat(1, 2048, CV_32FC1, testdata);
		float result = SVM.predict(test);
		if (result == 1)
			cout << "左偏30度" << endl;
		else if (result == 2)
			cout<< "右偏30度" <


你可能感兴趣的:(opencv svm 多分类问题)