参考文献:http://www.zilhua.com/629.html
http://www.tuicool.com/articles/JvMJve
http://blog.sina.com.cn/s/blog_573085f70101ivj5.html
我的数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
前言
提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的”树“,指的就是Decision Tree-----决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林=boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradient Boosting Decision Tree,就是梯度提升决策树,与随机森林的思想很像,但是比随机森林稍稍的难一点,当然效果相对于前者而言,也会好许多。由于本人才疏学浅,本文只会详细讲述Random Forest算法的部分,至于GBDT我会给出一小段篇幅做介绍引导,读者能够如果有兴趣的话,可以自行学习。
随机森林算法
决策树
要想理解随机森林算法,就不得不提决策树,什么是决策树,如何构造决策树,简单的回答就是数据的分类以树形结构的方式所展现,每个子分支都代表着不同的分类情况,比如下面的这个图所示:

当然决策树的每个节点分支不一定是三元的,可以有2个或者更多。分类的终止条件为,没有可以再拿来分类的属性条件或者说分到的数据的分类已经完全一致的情况。决策树分类的标准和依据是什么呢,下面介绍主要的2种划分标准。
1、信息增益。这是ID3算法系列所用的方法,C4.5算法在这上面做了少许的改进,用信息增益率来作为划分的标准,可以稍稍减小数据过于拟合的缺点。
2、基尼指数。这是CART分类回归树所用的方法。也是类似于信息增益的一个定义,最终都是根据数据划分后的纯度来做比较,这个纯度,你也可以理解为熵的变化,当然我们所希望的情况就是分类后数据的纯度更纯,也就是说,前后划分分类之后的熵的差越大越好。不过CART算法比较好的一点是树构造好后,还有剪枝的操作,剪枝操作的种类就比较多了,我之前在实现CART算法时用的是代价复杂度的剪枝方法。
这2种决策算法在我之前的博文中已经有所提及,不理解的可以点击我的ID3系列算法介绍和我的CART分类回归树算法。
Boosting
原本不打算将Boosting单独拉出来讲的,后来想想还是有很多内容可谈的。Boosting本身不是一种算法,他更应该说是一种思想,首先对数据构造n个弱分类器,最后通过组合n个弱分类器对于某个数据的判断结果作为最终的分类结果,就变成了一个强分类器,效果自然要好过单一分类器的分类效果。他可以理解为是一种提升算法,举一个比较常见的Boosting思想的算法AdaBoost,他在训练每个弱分类器的时候,提高了对于之前分错数据的权重值,最终能够组成一批相互互补的分类器集合。详细可以查看我的AdaBoost算法学习。
OK,2个重要的概念都已经介绍完毕,终于可以介绍主角Random Forest的出现了,正如前言中所说Random Forest=Decision Trees + Boosting,这里的每个弱分类器就是一个决策树了,不过这里的决策树都是二叉树,就是只有2个孩子分支,自然我立刻想到的做法就是用CART算法来构建,因为人家算法就是二元分支的。随机算法,随机算法,当然重在随机2个字上面,下面是2个方面体现了随机性。对于数据样本的采集量,比如我数据由100条,我可以每次随机取出其中的20条,作为我构造决策树的源数据,采取又放回的方式,并不是第一次抽到的数据,第二次不能重复,第二随机性体现在对于数据属性的随机采集,比如一行数据总共有10个特征属性,我每次随机采用其中的4个。正是由于对于数据的行压缩和列压缩,使得数据的随机性得以保证,就很难出现之前的数据过拟合的问题了,也就不需要在决策树最后进行剪枝操作了,这个是与一般的CART算法所不同的,尤其需要注意。
下面是随机森林算法的构造过程:
1、通过给定的原始数据,选出其中部分数据进行决策树的构造,数据选取是”有放回“的过程,我在这里用的是CART分类回归树。
2、随机森林构造完成之后,给定一组测试数据,使得每个分类器对其结果分类进行评估,最后取评估结果的众数最为最终结果。
算法非常的好理解,在Boosting算法和决策树之上做了一个集成,下面给出算法的实现,很多资料上只有大篇幅的理论,我还是希望能带给大家一点实在的东西。
随机算法的实现
输入数据(之前决策树算法时用过的)input.txt:
- Rid Age Income Student CreditRating BuysComputer
- 1 Youth High No Fair No
- 2 Youth High No Excellent No
- 3 MiddleAged High No Fair Yes
- 4 Senior Medium No Fair Yes
- 5 Senior Low Yes Fair Yes
- 6 Senior Low Yes Excellent No
- 7 MiddleAged Low Yes Excellent Yes
- 8 Youth Medium No Fair No
- 9 Youth Low Yes Fair Yes
- 10 Senior Medium Yes Fair Yes
- 11 Youth Medium Yes Excellent Yes
- 12 MiddleAged Medium No Excellent Yes
- 13 MiddleAged High Yes Fair Yes
- 14 Senior Medium No Excellent No
树节点类TreeNode.java:
- package DataMining_RandomForest;
-
- import java.util.ArrayList;
-
-
-
-
-
-
-
- public class TreeNode {
-
- private String attrName;
-
- private int nodeIndex;
-
- private int leafNum;
-
- private double alpha;
-
- private String parentAttrValue;
-
- private TreeNode[] childAttrNode;
-
- private ArrayList<String> dataIndex;
-
- public String getAttrName() {
- return attrName;
- }
-
- public void setAttrName(String attrName) {
- this.attrName = attrName;
- }
-
- public int getNodeIndex() {
- return nodeIndex;
- }
-
- public void setNodeIndex(int nodeIndex) {
- this.nodeIndex = nodeIndex;
- }
-
- public double getAlpha() {
- return alpha;
- }
-
- public void setAlpha(double alpha) {
- this.alpha = alpha;
- }
-
- public String getParentAttrValue() {
- return parentAttrValue;
- }
-
- public void setParentAttrValue(String parentAttrValue) {
- this.parentAttrValue = parentAttrValue;
- }
-
- public TreeNode[] getChildAttrNode() {
- return childAttrNode;
- }
-
- public void setChildAttrNode(TreeNode[] childAttrNode) {
- this.childAttrNode = childAttrNode;
- }
-
- public ArrayList<String> getDataIndex() {
- return dataIndex;
- }
-
- public void setDataIndex(ArrayList<String> dataIndex) {
- this.dataIndex = dataIndex;
- }
-
- public int getLeafNum() {
- return leafNum;
- }
-
- public void setLeafNum(int leafNum) {
- this.leafNum = leafNum;
- }
-
-
-
- }
决策树类DecisionTree.java:
随机森林算法工具类RandomForestTool.java:
CART算法工具类CARTTool.java:
- package DataMining_RandomForest;
-
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.LinkedList;
- import java.util.Queue;
-
-
-
-
-
-
-
- public class CARTTool {
-
- private final String YES = "Yes";
- private final String NO = "No";
-
-
- private int attrNum;
- private String filePath;
-
- private String[][] data;
-
- private String[] attrNames;
-
- private HashMap<String, ArrayList<String>> attrValue;
-
- public CARTTool(ArrayList<String[]> dataArray) {
- attrValue = new HashMap<>();
- readData(dataArray);
- }
-
-
-
-
-
-
- public void readData(ArrayList<String[]> dataArray) {
- data = new String[dataArray.size()][];
- dataArray.toArray(data);
- attrNum = data[0].length;
- attrNames = data[0];
- }
-
-
-
-
- public void initAttrValue() {
- ArrayList<String> tempValues;
-
-
- for (int j = 1; j < attrNum; j++) {
-
- tempValues = new ArrayList<>();
- for (int i = 1; i < data.length; i++) {
- if (!tempValues.contains(data[i][j])) {
-
- tempValues.add(data[i][j]);
- }
- }
-
-
- attrValue.put(data[0][j], tempValues);
- }
- }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- public double computeGini(String[][] remainData, String attrName,
- String value, boolean beLongValue) {
-
- int total = 0;
-
- int posNum = 0;
-
- int negNum = 0;
-
- double gini = 0;
-
-
- for (int j = 1; j < attrNames.length; j++) {
-
- if (attrName.equals(attrNames[j])) {
- for (int i = 1; i < remainData.length; i++) {
-
- if ((beLongValue && remainData[i][j].equals(value))
- || (!beLongValue && !remainData[i][j].equals(value))) {
- if (remainData[i][attrNames.length - 1].equals(YES)) {
-
- posNum++;
- } else {
- negNum++;
- }
- }
- }
- }
- }
-
- total = posNum + negNum;
- double posProbobly = (double) posNum / total;
- double negProbobly = (double) negNum / total;
- gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;
-
-
- return gini;
- }
-
-
-
-
-
-
-
-
-
-
- public String[] computeAttrGini(String[][] remainData, String attrName) {
- String[] str = new String[2];
-
- String spiltValue = "";
-
- int tempNum = 0;
-
- double minGini = Integer.MAX_VALUE;
- ArrayList<String> valueTypes = attrValue.get(attrName);
-
- HashMap<String, Integer> belongNum = new HashMap<>();
-
- for (String string : valueTypes) {
-
- tempNum = 0;
-
- for (int j = 1; j < attrNames.length; j++) {
-
- if (attrName.equals(attrNames[j])) {
- for (int i = 1; i < remainData.length; i++) {
-
- if (remainData[i][j].equals(string)) {
- tempNum++;
- }
- }
- }
- }
-
- belongNum.put(string, tempNum);
- }
-
- double tempGini = 0;
- double posProbably = 1.0;
- double negProbably = 1.0;
- for (String string : valueTypes) {
- tempGini = 0;
-
- posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
- negProbably = 1 - posProbably;
-
- tempGini += posProbably
- * computeGini(remainData, attrName, string, true);
- tempGini += negProbably
- * computeGini(remainData, attrName, string, false);
-
- if (tempGini < minGini) {
- minGini = tempGini;
- spiltValue = string;
- }
- }
-
- str[0] = spiltValue;
- str[1] = minGini + "";
-
- return str;
- }
-
- public void buildDecisionTree(TreeNode node, String parentAttrValue,
- String[][] remainData, ArrayList<String> remainAttr,
- boolean beLongParentValue) {
-
- String valueType = "";
-
- String spiltAttrName = "";
- double minGini = Integer.MAX_VALUE;
- double tempGini = 0;
-
- String[] giniArray;
-
- if (beLongParentValue) {
- node.setParentAttrValue(parentAttrValue);
- } else {
- node.setParentAttrValue("!" + parentAttrValue);
- }
-
- if (remainAttr.size() == 0) {
- if (remainData.length > 1) {
- ArrayList<String> indexArray = new ArrayList<>();
- for (int i = 1; i < remainData.length; i++) {
- indexArray.add(remainData[i][0]);
- }
- node.setDataIndex(indexArray);
- }
-
- return;
- }
-
- for (String str : remainAttr) {
- giniArray = computeAttrGini(remainData, str);
- tempGini = Double.parseDouble(giniArray[1]);
-
- if (tempGini < minGini) {
- spiltAttrName = str;
- minGini = tempGini;
- valueType = giniArray[0];
- }
- }
-
- remainAttr.remove(spiltAttrName);
- node.setAttrName(spiltAttrName);
-
-
- TreeNode[] childNode = new TreeNode[2];
- String[][] rData;
-
- boolean[] bArray = new boolean[] { true, false };
- for (int i = 0; i < bArray.length; i++) {
-
- rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);
-
- boolean sameClass = true;
- ArrayList<String> indexArray = new ArrayList<>();
- for (int k = 1; k < rData.length; k++) {
- indexArray.add(rData[k][0]);
-
- if (!rData[k][attrNames.length - 1]
- .equals(rData[1][attrNames.length - 1])) {
-
- sameClass = false;
- break;
- }
- }
-
- childNode[i] = new TreeNode();
- if (!sameClass) {
-
- ArrayList<String> rAttr = new ArrayList<>();
- for (String str : remainAttr) {
- rAttr.add(str);
- }
- buildDecisionTree(childNode[i], valueType, rData, rAttr,
- bArray[i]);
- } else {
- String pAtr = (bArray[i] ? valueType : "!" + valueType);
- childNode[i].setParentAttrValue(pAtr);
- childNode[i].setDataIndex(indexArray);
- }
- }
-
- node.setChildAttrNode(childNode);
- }
-
-
-
-
-
-
-
-
-
-
-
-
- private String[][] removeData(String[][] srcData, String attrName,
- String valueType, boolean beLongValue) {
- String[][] desDataArray;
- ArrayList<String[]> desData = new ArrayList<>();
-
- ArrayList<String[]> selectData = new ArrayList<>();
- selectData.add(attrNames);
-
-
- for (int i = 0; i < srcData.length; i++) {
- desData.add(srcData[i]);
- }
-
-
- for (int j = 1; j < attrNames.length; j++) {
- if (attrNames[j].equals(attrName)) {
- for (int i = 1; i < desData.size(); i++) {
- if (desData.get(i)[j].equals(valueType)) {
-
- selectData.add(desData.get(i));
- }
- }
- }
- }
-
- if (beLongValue) {
- desDataArray = new String[selectData.size()][];
- selectData.toArray(desDataArray);
- } else {
-
- selectData.remove(attrNames);
-
- desData.removeAll(selectData);
- desDataArray = new String[desData.size()][];
- desData.toArray(desDataArray);
- }
-
- return desDataArray;
- }
-
-
-
-
-
- public TreeNode startBuildingTree() {
- initAttrValue();
-
- ArrayList<String> remainAttr = new ArrayList<>();
-
- for (int i = 1; i < attrNames.length - 1; i++) {
- remainAttr.add(attrNames[i]);
- }
-
- TreeNode rootNode = new TreeNode();
- buildDecisionTree(rootNode, "", data, remainAttr, false);
- setIndexAndAlpah(rootNode, 0, false);
- showDecisionTree(rootNode, 1);
-
- return rootNode;
- }
-
-
-
-
-
-
-
-
-
- private void showDecisionTree(TreeNode node, int blankNum) {
- System.out.println();
- for (int i = 0; i < blankNum; i++) {
- System.out.print(" ");
- }
- System.out.print("--");
-
- if (node.getParentAttrValue() != null
- && node.getParentAttrValue().length() > 0) {
- System.out.print(node.getParentAttrValue());
- } else {
- System.out.print("--");
- }
- System.out.print("--");
-
- if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
- String i = node.getDataIndex().get(0);
- System.out.print("【" + node.getNodeIndex() + "】类别:"
- + data[Integer.parseInt(i)][attrNames.length - 1]);
- System.out.print("[");
- for (String index : node.getDataIndex()) {
- System.out.print(index + ", ");
- }
- System.out.print("]");
- } else {
-
- System.out.print("【" + node.getNodeIndex() + ":"
- + node.getAttrName() + "】");
- if (node.getChildAttrNode() != null) {
- for (TreeNode childNode : node.getChildAttrNode()) {
- showDecisionTree(childNode, 2 * blankNum);
- }
- } else {
- System.out.print("【 Child Null】");
- }
- }
- }
-
-
-
-
-
-
-
-
-
-
-
- private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) {
- TreeNode tempNode;
-
- TreeNode minAlphaNode = null;
- double minAlpah = Integer.MAX_VALUE;
- Queue<TreeNode> nodeQueue = new LinkedList<TreeNode>();
-
- nodeQueue.add(node);
- while (nodeQueue.size() > 0) {
- index++;
-
- tempNode = nodeQueue.poll();
- tempNode.setNodeIndex(index);
- if (tempNode.getChildAttrNode() != null) {
- for (TreeNode childNode : tempNode.getChildAttrNode()) {
- nodeQueue.add(childNode);
- }
- computeAlpha(tempNode);
- if (tempNode.getAlpha() < minAlpah) {
- minAlphaNode = tempNode;
- minAlpah = tempNode.getAlpha();
- } else if (tempNode.getAlpha() == minAlpah) {
-
- if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
- minAlphaNode = tempNode;
- }
- }
- }
- }
-
- if (ifCutNode) {
-
- minAlphaNode.setChildAttrNode(null);
- }
- }
-
-
-
-
-
-
-
- private void computeAlpha(TreeNode node) {
- double rt = 0;
- double Rt = 0;
- double alpha = 0;
-
- int sumNum = 0;
-
- int minNum = 0;
-
- ArrayList<String> dataIndex;
- ArrayList<TreeNode> leafNodes = new ArrayList<>();
-
- addLeafNode(node, leafNodes);
- node.setLeafNum(leafNodes.size());
- for (TreeNode attrNode : leafNodes) {
- dataIndex = attrNode.getDataIndex();
-
- int num = 0;
- sumNum += dataIndex.size();
- for (String s : dataIndex) {
-
- if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
- num++;
- }
- }
- minNum += num;
-
-
- if (1.0 * num / dataIndex.size() > 0.5) {
- num = dataIndex.size() - num;
- }
-
- rt += (1.0 * num / (data.length - 1));
- }
-
-
- if (1.0 * minNum / sumNum > 0.5) {
- minNum = sumNum - minNum;
- }
-
- Rt = 1.0 * minNum / (data.length - 1);
- alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
- node.setAlpha(alpha);
- }
-
-
-
-
-
-
-
-
-
- private void addLeafNode(TreeNode node, ArrayList<TreeNode> leafNode) {
- ArrayList<String> dataIndex;
-
- if (node.getChildAttrNode() != null) {
- for (TreeNode childNode : node.getChildAttrNode()) {
- dataIndex = childNode.getDataIndex();
- if (dataIndex != null && dataIndex.size() > 0) {
-
- leafNode.add(childNode);
- } else {
-
- addLeafNode(childNode, leafNode);
- }
- }
- }
- }
-
- }
测试类Client.java:
- package DataMining_RandomForest;
-
- import java.text.MessageFormat;
-
-
-
-
-
-
-
- public class Client {
- public static void main(String[] args) {
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
- String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair";
- String resultClassType = "";
-
- double sampleNumRatio = 0.4;
-
- double featureNumRatio = 0.5;
-
- RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio,
- featureNumRatio);
- tool.constructRandomTree();
-
- resultClassType = tool.judgeClassType(queryStr);
-
- System.out.println();
- System.out
- .println(MessageFormat.format(
- "查询属性描述{0},预测的分类结果为BuysCompute:{1}", queryStr,
- resultClassType));
- }
- }
算法的输出
- 下面是随机森林中的决策树:
-
- 决策树1
-
- --!--【1:Income】
- --Medium--【2】类别:Yes[1, 2, ]
- --!Medium--【3:Student】
- --No--【4】类别:No[3, 5, ]
- --!No--【5】类别:Yes[4, ]
- 决策树2
-
- --!--【1:Student】
- --No--【2】类别:No[1, 3, ]
- --!No--【3】类别:Yes[2, 4, 5, ]
- 查询属性描述Age=Youth,Income=Low,Student=No,CreditRating=Fair,预测的分类结果为BuysCompute:No
输出的结果决策树建议从左往右看,从上往下,【】符号表示一个节点,---XX---表示属性值的划分,你就应该能看懂这棵树了,在console上想展示漂亮的树形效果的确很难。。。这里说一个算法的重大不足,数据太少,导致选择的样本数据不足,所选属性太少,,构造的决策树数量过少,自然分类的准确率不见得会有多准,博友只要能领会代码中所表达的算法的思想即可。
GBDT
下面来说说随机森林的兄弟算法GBDT,梯度提升决策树,他有很多的决策树,他也有组合的思想,但是他不是随机森林算法2,GBDT的关键在于Gradient Boosting,梯度提升。这个词语理解起来就不容易了。学术的描述,每一次建立模型是在之前建立模型的损失函数的梯度下降方向。GBDT的核心在于,每一棵树学的是之前所有树结论和的残差,这个残差你可以理解为与预测值的差值。举个例子:比如预测张三的年龄,张三的真实年龄18岁,第一棵树预测张的年龄12岁,此时残差为18-12=6岁,因此在第二棵树中,我们把张的年龄作为6岁去学习,如果预测成功了,则张的真实年龄就是A树和B树的结果预测值的和,但是如果B预测成了5岁,那么残差就变成了6-5=1岁,那么此时需要构建第三树对1岁做预测,后面一样的道理。每棵树都是对之前失败预测的一个补充,用公式的表达就是如下的这个样子:

F0在这里是初始值,Ti是一棵棵的决策树,不同的问题选择不同的损失函数和初始值。在阿里内部对于此算法的叫法为TreeLink。所以下次听到什么Treelink算法了指的就是梯度提升树算法,其实我在这里省略了很大篇幅的数学推导过程,再加上自己还不是专家,无法彻底解释清数学的部分,所以就没有提及,希望以后有时间可以深入学习此方面的知识。