OpenCV 中训练集和测试集的写入通过 format(Mat, Formatter::FMT_CSV) 方式完成;
读取则通过 TrainData::loadFromCSV 完成。
具体代码实现如下:
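/* * Builds an image-classification dataset: reads a list of image paths, flattens each resized grayscale image into one CSV row (with a trailing label column) via format(Mat, Formatter::FMT_CSV), and reads the file back with TrainData::loadFromCSV. */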
#include <cstdio>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/ml.hpp"
// Bring frequently used std, OpenCV (cv) and OpenCV ML (cv::ml) names into the global scope.
// NOTE(review): the explicit std using-declarations are redundant with `using namespace std;`,
// and combining the std and cv using-directives at file scope can make unqualified names
// ambiguous; prefer qualified names in new code.
using std::string;
using std::cout;
using std::endl;
using std::cerr;
using std::ostream;
using namespace cv;
using namespace cv::ml;
using namespace std;
// Geometry every input image must have before it is resized and flattened.
// (Previously three #defines collapsed onto one line, which made IMG_WIDTH expand to the
// rest of the line and left IMG_HEIGHT / IMG_CHANNEL undefined.)
constexpr int IMG_WIDTH = 128;   // expected number of columns
constexpr int IMG_HEIGHT = 128;  // expected number of rows
constexpr int IMG_CHANNEL = 1;   // expected channel count (grayscale)
int main(int ac, char** av) { string filename ; string path; if(ac > 2 ) { filename = av[2]; path = av[1]; } else { printf(" should be ./*.out path.txt outputfilename\n"); } printf("%s(%d):%s\n", __FUNCTION__, __LINE__, filename.c_str());
if (filename.empty()) { printf("Error:fileName NULL\n"); return 1; }
//write {
ifstream fin; string tmp; fin.open(path.c_str()); // FILE* fpt = fopen(filename.c_str(),(const char*)"w"); ofstream file(filename.c_str());
cout << "writing images\n"; int i = 0; vector
while(fin >> tmp) { #if 0 stringstream tmpStr; tmpStr<<"image"< cout < i++; #endif FileVec.push_back(tmp); } vector vector Mat MergeOut; int Cnt = 0; for(strIt = FileVec.begin();strIt != FileVec.end(); strIt++, Cnt++ ) { Mat inputMat = imread(*strIt, 0); int rows = inputMat.rows; int cols = inputMat.cols; int channels = inputMat.channels(); vector imgVect.resize(3); printf("%d,%d,%d\n", inputMat.rows, inputMat.cols, inputMat.channels()); if(rows != IMG_WIDTH || cols != IMG_HEIGHT || channels != IMG_CHANNEL ) { printf("Error %s: rows should be %d, cols should be %d, channels should be %d\n", strIt->c_str(), IMG_WIDTH, IMG_HEIGHT, IMG_CHANNEL); return 0; } #if 0 split(inputMat, imgVect); unsigned char *pData = inputMat.data; ofstream file("test.csv"); file << format(imgVect[0], Formatter::FMT_CSV); file.close(); #endif Mat imgResize; resize(inputMat, imgResize, Size(64,64), 64, 64, INTER_LINEAR); // Mat ImgForTrain(1, inputMat.rows*inputMat.cols, CV_8UC1, (void*)inputMat.data); Mat ImgForTrain(1, imgResize.rows*imgResize.cols, CV_8UC1, (void*)imgResize.data); AllImage.push_back(ImgForTrain); if(Cnt == 0) MergeOut = ImgForTrain; else vconcat(ImgForTrain, MergeOut, MergeOut); //file << format(ImgForTrain, Formatter::FMT_CSV); printf("%d,%d,%d\n", MergeOut.rows, MergeOut.cols, MergeOut.channels()); } // cout < printf("%d,%d,%d\n", MergeOut.rows, MergeOut.cols, MergeOut.channels()); Mat ImgLabel(MergeOut.rows, 1, CV_8UC1); for(int i = 0; i < MergeOut.rows; i++) { if(i < MergeOut.rows/2) { ImgLabel.at } else ImgLabel.at } hconcat( MergeOut, ImgLabel, MergeOut); file << format(MergeOut, Formatter::FMT_CSV); // fclose(fpt); fin.close(); file.close(); } //read { cout <<"\n"< Ptr train_data = TrainData::loadFromCSV(filename.c_str(), 0); Mat m = train_data->getTrainSamples(); //cout << format(m, Formatter::FMT_CSV) << endl; printf("\n%d,%d,%d\n", m.rows, m.cols, m.channels()); // cout < // cout << m.row(0) << endl; Mat label = train_data->getTrainResponses(); 
printf("\n%d,%d,%d\n", label.rows, label.cols, label.channels()); } return 0; } |
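该代码已验证。
通过它,我生成了 64x64 分辨率的训练集(戴眼镜、不戴眼镜各 500 张)和验证集(戴眼镜、不戴眼镜各 100 张)。输入图像须为 128x128 的单通道(灰度)图像,程序会将其缩放到 64x64 后按行展开。
其中,第一个参数是输入的图像列表文件(以空白分隔的图像路径);第二个参数是生成的数据集(CSV)文件名。标签按图像在列表中的位置分配:前一半与后一半图像获得不同的标签,因此列表中两类图像应各占一半,并按类别分组排列。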