决策树的C++实现(CART)

关于决策树的介绍可以参考: https://blog.csdn.net/fengbingchun/article/details/78880934

CART算法的决策树的Python实现可以参考: https://blog.csdn.net/fengbingchun/article/details/78881143

这里参考 https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/ 这篇文章的原有Python实现,使用C++实现了决策树的CART算法,测试数据集是Banknote Dataset,关于Banknote Dataset的介绍可以参考: https://blog.csdn.net/fengbingchun/article/details/78624358 。

decision_tree.hpp文件内容如下:

#ifndef FBC_NN_DECISION_TREE_HPP_
#define FBC_NN_DECISION_TREE_HPP_

#include <vector>
#include <tuple>
#include <fstream>

namespace ANN {
// reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/

template<typename T>
class DecisionTree { // CART(Classification and Regression Trees)
public:
	DecisionTree() = default;
	~DecisionTree() { delete_tree(); }
	// Copy the training data; each row's last column is the class label
	int init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes);
	void set_max_depth(int max_depth) { this->max_depth = max_depth; }
	int get_max_depth() const { return max_depth; }
	void set_min_size(int min_size) { this->min_size = min_size; }
	int get_min_size() const { return min_size; }
	void train();
	int save_model(const char* name) const;
	int load_model(const char* name);
	T predict(const std::vector<T>& data) const;

protected:
	// (index of attribute, value of attribute, groups of data split by that value)
	typedef std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> dictionary;
	// serialized node: (flag, index, value, class_value_left, class_value_right); flag == -1 means "no node"
	typedef std::tuple<int, int, T, T, T> row_element;
	typedef struct binary_tree {
		dictionary dict;
		T class_value_left = (T)-1.f;  // leaf label taken when going left and there is no left child
		T class_value_right = (T)-1.f; // leaf label taken when going right and there is no right child
		binary_tree* left = nullptr;
		binary_tree* right = nullptr;
	} binary_tree;

	// Calculate the Gini index for a split dataset
	T gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const;
	// Select the best split point for a dataset
	dictionary get_split(const std::vector<std::vector<T>>& dataset) const;
	// Split a dataset based on an attribute and an attribute value
	std::vector<std::vector<std::vector<T>>> test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const;
	// Create a terminal node value
	T to_terminal(const std::vector<std::vector<T>>& group) const;
	// Create child splits for a node or make terminal
	void split(binary_tree* node, int depth);
	// Build a decision tree
	void build_tree(const std::vector<std::vector<T>>& train);
	// Print a decision tree
	void print_tree(const binary_tree* node, int depth = 0) const;
	// Make a prediction with a decision tree
	T predict(binary_tree* node, const std::vector<T>& data) const;
	// calculate accuracy percentage on the training data
	double accuracy_metric() const;
	void delete_tree();
	void delete_node(binary_tree* node);
	// serialization helpers: the tree is stored as a heap-ordered array of row_elements
	void write_node(const binary_tree* node, std::ofstream& file) const;
	void node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const;
	int height_of_tree(const binary_tree* node) const;
	void row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos);

private:
	std::vector<std::vector<T>> src_data;
	binary_tree* tree = nullptr;
	int samples_num = 0;
	int feature_length = 0;
	int classes_num = 0;
	int max_depth = 10; // maximum tree depth
	int min_size = 10; // minimum node records
	int max_nodes = -1; // (1 << max_depth) - 1: size of the serialized heap array
};

} // namespace ANN


#endif // FBC_NN_DECISION_TREE_HPP_

decision_tree.cpp文件内容如下:

#include "decision_tree.hpp"
#include <algorithm>
#include <cstdio>
#include <set>
#include <sstream>
#include <string>
#include <typeinfo>
#include <vector>
#include "common.hpp"

namespace ANN {

template
int DecisionTree::init(const std::vector>& data, const std::vector& classes)
{
	CHECK(data.size() != 0 && classes.size() != 0 && data[0].size() != 0);

	this->samples_num = data.size();
	this->classes_num = classes.size();
	this->feature_length = data[0].size() -1;

	for (int i = 0; i < this->samples_num; ++i) {
		this->src_data.emplace_back(data[i]);
	}

	return 0;
}

template
T DecisionTree::gini_index(const std::vector>>& groups, const std::vector& classes) const
{
	// Gini calculation for a group
	// proportion = count(class_value) / count(rows)
	// gini_index = (1.0 - sum(proportion * proportion)) * (group_size/total_samples)

	// count all samples at split point
	int instances = 0;
	int group_num = groups.size();
	for (int i = 0; i < group_num; ++i) {
		instances += groups[i].size();
	}

	// sum weighted Gini index for each group
	T gini = (T)0.;
	for (int i = 0; i < group_num; ++i) {
		int size = groups[i].size();
		// avoid divide by zero
		if (size == 0) continue;
		T score = (T)0.;

		// score the group based on the score for each class
		T p = (T)0.;
		for (int c = 0; c < classes.size(); ++c) {
			int count = 0;
			for (int t = 0; t < size; ++t) {
				if (groups[i][t][this->feature_length] == classes[c]) ++count;
			}
			T p = (float)count / size;
			score += p * p;
		}

		// weight the group score by its relative size
		gini += (1. - score) * (float)size / instances;
	}

	return gini;
}

template
std::vector>> DecisionTree::test_split(int index, T value, const std::vector>& dataset) const
{
	std::vector>> groups(2); // 0: left, 1: reight

	for (int row = 0; row < dataset.size(); ++row) {
		if (dataset[row][index] < value) {
			groups[0].emplace_back(dataset[row]);
		} else {
			groups[1].emplace_back(dataset[row]);
		}
	}

	return groups;
}

template
std::tuple>>> DecisionTree::get_split(const std::vector>& dataset) const
{
	std::vector values;
	for (int i = 0; i < dataset.size(); ++i) {
		values.emplace_back(dataset[i][this->feature_length]);
	}

	std::set vals(values.cbegin(), values.cend());
	std::vector class_values(vals.cbegin(), vals.cend());

	int b_index = 999;
	T b_value = (T)999.;
	T b_score = (T)999.;
	std::vector>> b_groups(2);

	for (int index = 0; index < this->feature_length; ++index) {
		for (int row = 0; row < dataset.size(); ++row) {
			std::vector>> groups = test_split(index, dataset[row][index], dataset);
			T gini = gini_index(groups, class_values);

			if (gini < b_score) {
				b_index = index;
				b_value = dataset[row][index];
				b_score = gini;
				b_groups = groups;
			}
		}
	}

	// a new node: the index of the chosen attribute, the value of that attribute by which to split and the two groups of data split by the chosen split point
	return std::make_tuple(b_index, b_value, b_groups);
}

template
T DecisionTree::to_terminal(const std::vector>& group) const
{
	std::vector values;
	for (int i = 0; i < group.size(); ++i) {
		values.emplace_back(group[i][this->feature_length]);
	}

	std::set vals(values.cbegin(), values.cend());
	int max_count = -1, index = -1;
	for (int i = 0; i < vals.size(); ++i) {
		int count = std::count(values.cbegin(), values.cend(), *std::next(vals.cbegin(), i));
		if (max_count < count) {
			max_count = count;
			index = i;
		}
	}

	return *std::next(vals.cbegin(), index);
}

template
void DecisionTree::split(binary_tree* node, int depth)
{
	std::vector> left = std::get<2>(node->dict)[0];
	std::vector> right = std::get<2>(node->dict)[1];
	std::get<2>(node->dict).clear();

	// check for a no split
	if (left.size() == 0 || right.size() == 0) {
		for (int i = 0; i < right.size(); ++i) {
			left.emplace_back(right[i]);
		}

		node->class_value_left = node->class_value_right = to_terminal(left);
		return;
	}

	// check for max depth
	if (depth >= max_depth) {
		node->class_value_left = to_terminal(left);
		node->class_value_right = to_terminal(right);
		return;
	}

	// process left child
	if (left.size() <= min_size) {
		node->class_value_left = to_terminal(left);
	} else {
		dictionary dict = get_split(left);
		node->left = new binary_tree;
		node->left->dict = dict;
		split(node->left, depth+1);
	}

	// process right child
	if (right.size() <= min_size) {
		node->class_value_right = to_terminal(right);
	} else {
		dictionary dict = get_split(right);
		node->right = new binary_tree;
		node->right->dict = dict;
		split(node->right, depth+1);
	}
}

template
void DecisionTree::build_tree(const std::vector>& train)
{
	// create root node
	dictionary root = get_split(train);
	binary_tree* node = new binary_tree;
	node->dict = root;
	tree = node;
	split(node, 1);
}

template
void DecisionTree::train()
{
	this->max_nodes = (1 << max_depth) - 1;
	build_tree(src_data);

	accuracy_metric();
	
	//binary_tree* tmp = tree;
	//print_tree(tmp);
}

template
T DecisionTree::predict(const std::vector& data) const
{
	if (!tree) {
		fprintf(stderr, "Error, tree is null\n");
		return -1111.f;
	}

	return predict(tree, data);
}

template
T DecisionTree::predict(binary_tree* node, const std::vector& data) const
{
	if (data[std::get<0>(node->dict)] < std::get<1>(node->dict)) {
		if (node->left) {
			return predict(node->left, data);
		} else {
			return node->class_value_left;
		}
	} else {
		if (node->right) {
			return predict(node->right, data);
		} else {
			return node->class_value_right;
		}
	}
}

template
int DecisionTree::save_model(const char* name) const
{
	std::ofstream file(name, std::ios::out);
	if (!file.is_open()) {
		fprintf(stderr, "open file fail: %s\n", name);
		return -1;
	}

	file<
void DecisionTree::write_node(const binary_tree* node, std::ofstream& file) const
{
	/*if (!node) return;

	write_node(node->left, file);
	file<(node->dict)<<","<(node->dict)<<","<class_value_left<<","<class_value_right<right, file);*/
	
	//typedef std::tuple row; // flag, index, value, class_value_left, class_value_right
	std::vector vec(this->max_nodes, std::make_tuple(-1, -1, (T)-1.f, (T)-1.f, (T)-1.f));

	binary_tree* tmp = const_cast(node);
	node_to_row_element(tmp, vec, 0);

	for (const auto& row : vec) {
		file<(row)<<","<(row)<<","<(row)<<","<(row)<<","<(row)<
void DecisionTree::node_to_row_element(binary_tree* node, std::vector& rows, int pos) const
{
	if (!node) return;

	rows[pos] = std::make_tuple(0, std::get<0>(node->dict), std::get<1>(node->dict), node->class_value_left, node->class_value_right); // 0: have node, -1: no node
	
	if (node->left) node_to_row_element(node->left, rows, 2*pos+1);
	if (node->right) node_to_row_element(node->right, rows, 2*pos+2);
}

template
int DecisionTree::height_of_tree(const binary_tree* node) const
{
	if (!node)
		return 0;
	else
		return std::max(height_of_tree(node->left), height_of_tree(node->right)) + 1;
}

template
int DecisionTree::load_model(const char* name)
{
	std::ifstream file(name, std::ios::in);
	if (!file.is_open()) {
		fprintf(stderr, "open file fail: %s\n", name);
		return -1;
	}

	std::string line, cell;
	std::getline(file, line);
	std::stringstream line_stream(line);
	std::vector vec;
	int count = 0;
	while (std::getline(line_stream, cell, ',')) {
		vec.emplace_back(std::stoi(cell));
	}
	CHECK(vec.size() == 2);
	max_depth = vec[0];
	min_size = vec[1];
	max_nodes = (1 << max_depth) - 1;
	std::vector rows(max_nodes);
	
	if (typeid(float).name() == typeid(T).name()) {
		while (std::getline(file, line)) {
			std::stringstream line_stream2(line);
			std::vector vec2;
		
			while(std::getline(line_stream2, cell, ',')) {
				vec2.emplace_back(std::stof(cell));
			}
			
			CHECK(vec2.size() == 5);
			rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec2[4]);
			//fprintf(stderr, "%d, %d, %f, %f, %f\n", std::get<0>(rows[count]), std::get<1>(rows[count]), std::get<2>(rows[count]), std::get<3>(rows[count]), std::get<4>(rows[count]));
			++count;
		}
	} else { // double
		while (std::getline(file, line)) {
			std::stringstream line_stream2(line);
			std::vector vec2;
		
			while(std::getline(line_stream2, cell, ',')) {
				vec2.emplace_back(std::stod(cell));
			}

			CHECK(vec2.size() == 5);
			rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec[4]);
			++count;
		}
	}

	CHECK(max_nodes == count);
	CHECK(std::get<0>(rows[0]) != -1);

	binary_tree* tmp = new binary_tree;
	std::vector>> dump;
	tmp->dict = std::make_tuple(std::get<1>(rows[0]), std::get<2>(rows[0]), dump);
	tmp->class_value_left = std::get<3>(rows[0]);
	tmp->class_value_right = std::get<4>(rows[0]);
	tree = tmp;
	row_element_to_node(tmp, rows, max_nodes, 0);

	file.close();
	return 0;
}

template
void DecisionTree::row_element_to_node(binary_tree* node, const std::vector& rows, int n, int pos)
{
	if (!node || n == 0) return;

	int new_pos = 2 * pos + 1;
	if (new_pos < n && std::get<0>(rows[new_pos]) != -1) {
		node->left = new binary_tree;
		std::vector>> dump;
		node->left->dict = std::make_tuple(std::get<1>(rows[new_pos]), std::get<2>(rows[new_pos]), dump);
		node->left->class_value_left = std::get<3>(rows[new_pos]);
		node->left->class_value_right = std::get<4>(rows[new_pos]);

		row_element_to_node(node->left, rows, n, new_pos);
	}

	new_pos = 2 * pos + 2;
	if (new_pos < n && std::get<0>(rows[new_pos]) != -1) {
		node->right = new binary_tree;
		std::vector>> dump;
		node->right->dict = std::make_tuple(std::get<1>(rows[new_pos]), std::get<2>(rows[new_pos]), dump);
		node->right->class_value_left = std::get<3>(rows[new_pos]);
		node->right->class_value_right = std::get<4>(rows[new_pos]);
	
		row_element_to_node(node->right, rows, n, new_pos);
	}
}

template
void DecisionTree::delete_tree()
{
	delete_node(tree);
}

template
void DecisionTree::delete_node(binary_tree* node)
{
	if (node->left) delete_node(node->left);
	if (node->right) delete_node(node->right);
	delete node;
}

template
double DecisionTree::accuracy_metric() const
{
	int correct = 0;
	for (int i = 0; i < this->samples_num; ++i) {
		T predicted = predict(tree, src_data[i]);
		if (predicted == src_data[i][this->feature_length])
			++correct;
	}

	double accuracy = correct / (double)samples_num * 100.;
	fprintf(stdout, "train accuracy: %f\n", accuracy);

	return accuracy;  
}

template
void DecisionTree::print_tree(const binary_tree* node, int depth) const
{
	if (node) {
		std::string blank = " ";
		for (int i = 0; i < depth; ++i) blank += blank;
		fprintf(stdout, "%s[X%d < %.3f]\n", blank.c_str(), std::get<0>(node->dict)+1, std::get<1>(node->dict));

		if (!node->left || !node->right)
			blank += blank;

		if (!node->left)
			fprintf(stdout, "%s[%.1f]\n", blank.c_str(), node->class_value_left);
		else 
			print_tree(node->left, depth+1);

		if (!node->right)
			fprintf(stdout, "%s[%.1f]\n", blank.c_str(), node->class_value_right);
		else
			print_tree(node->right, depth+1);
			
	}
}

template class DecisionTree;
template class DecisionTree;

} // namespace ANN

对外提供两个接口,一个是test_decision_tree_train用于训练,一个是test_decision_tree_predict用于测试,其code如下:

// =============================== decision tree ==============================
int test_decision_tree_train()
{
	// small dataset test
	/*const std::vector> data{ { 2.771244718f, 1.784783929f, 0.f },
					{ 1.728571309f, 1.169761413f, 0.f },
					{ 3.678319846f, 2.81281357f, 0.f },
					{ 3.961043357f, 2.61995032f, 0.f },
					{ 2.999208922f, 2.209014212f, 0.f },
					{ 7.497545867f, 3.162953546f, 1.f },
					{ 9.00220326f, 3.339047188f, 1.f },
					{ 7.444542326f, 0.476683375f, 1.f },
					{ 10.12493903f, 3.234550982f, 1.f },
					{ 6.642287351f, 3.319983761f, 1.f } };

	const std::vector classes{ 0.f, 1.f };

	ANN::DecisionTree dt;
	dt.init(data, classes);
	dt.set_max_depth(3);
	dt.set_min_size(1);

	dt.train();
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	dt.save_model(model_name);

	ANN::DecisionTree dt2;
	dt2.load_model(model_name);
	const std::vector> test{{0.6f, 1.9f, 0.f}, {9.7f, 4.3f, 1.f}};
	for (const auto& row : test) {
		float ret = dt2.predict(row);
		fprintf(stdout, "predict result: %.1f, actural value: %.1f\n", ret, row[2]);
	} */

	// banknote authentication dataset
#ifdef _MSC_VER
	const char* file_name = "E:/GitCode/NN_Test/data/database/BacknoteDataset/data_banknote_authentication.txt";
#else
	const char* file_name = "data/database/BacknoteDataset/data_banknote_authentication.txt";
#endif

	std::vector> data;
	int ret = read_txt_file(file_name, data, ',', 1372, 5);
	if (ret != 0) {
		fprintf(stderr, "parse txt file fail: %s\n", file_name);
		return -1;
	}

	//fprintf(stdout, "data size: rows: %d\n", data.size());

	const std::vector classes{ 0.f, 1.f };
	ANN::DecisionTree dt;
	dt.init(data, classes);
	dt.set_max_depth(6);
	dt.set_min_size(10);
	dt.train();
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	dt.save_model(model_name);

	return 0;
}

int test_decision_tree_predict()
{
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	ANN::DecisionTree dt;
	dt.load_model(model_name);
	int max_depth = dt.get_max_depth();
	int min_size = dt.get_min_size();
	fprintf(stdout, "max_depth: %d, min_size: %d\n", max_depth, min_size);

	std::vector> test {{-2.5526,-7.3625,6.9255,-0.66811,1},
				       {-4.5531,-12.5854,15.4417,-1.4983,1},
				       {4.0948,-2.9674,2.3689,0.75429,0},
				       {-1.0401,9.3987,0.85998,-5.3336,0},
				       {1.0637,3.6957,-4.1594,-1.9379,1}};
	for (const auto& row : test) {	
		float ret = dt.predict(row);
		fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[4]);
	}

	return 0;
}

训练接口执行结果如下:

测试接口执行结果如下:

决策树的C++实现(CART)_第1张图片

训练时生成的模型decision_tree.model内容如下:

6,10
0,0,0.3223,-1,-1
0,1,7.6274,-1,-1
0,2,-4.3839,-1,-1
0,0,-0.39816,-1,-1
0,0,-4.2859,-1,-1
0,0,4.2164,-1,0
0,0,1.594,-1,-1
0,2,6.2204,-1,-1
0,1,5.8974,-1,-1
0,0,-5.4901,-1,1
0,0,-1.5768,-1,-1
0,0,0.47368,1,-1
-1,-1,-1,-1,-1
0,2,-2.2718,-1,-1
0,0,2.0421,-1,-1
0,1,7.3273,-1,1
0,1,-4.6062,-1,-1
0,2,3.1143,-1,-1
0,0,0.049175,0,0
0,0,-6.2003,1,1
-1,-1,-1,-1,-1
0,0,-2.7419,0,-1
0,0,-1.5768,0,0
-1,-1,-1,-1,-1
0,0,0.47368,1,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,1,7.6377,-1,0
0,3,0.097399,-1,-1
0,2,-2.3386,1,-1
0,0,3.6216,-1,-1
0,0,-1.3971,1,1
-1,-1,-1,-1,-1
0,0,-1.6677,1,1
0,0,-1.7781,0,0
0,0,-0.36506,1,1
0,3,1.547,0,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,-2.7419,0,0
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,1.0552,1,1
-1,-1,-1,-1,-1
0,0,0.4339,0,0
0,2,2.0013,1,0
-1,-1,-1,-1,-1
0,0,1.8993,0,0
0,0,3.4566,0,0
0,0,3.6216,0,0

GitHub: https://github.com/fengbingchun/NN_Test 

你可能感兴趣的:(Deep,Learning)