// Code adapted from Mr. Min Fan's "Java in 300 lines a day" series
// (days 61-70: decision trees and ensemble learning), CSDN blog.
package machinelearning.decisiontree;
import java.io.FileReader;
import java.util.Arrays;
import weka.core.Instance;
import weka.core.Instances;
/**
* The ID3 decision tree inductive algorithm.
*
* @author WX873
*
*/
public class ID3 {
    /**
     * The training data set, shared by every node of the tree.
     */
    Instances dataset;

    /**
     * Is this node's data block pure (all instances share one label)?
     */
    boolean pure;

    /**
     * The number of classes. For binary classification it is 2.
     */
    int numClasses;

    /**
     * Indices of the instances belonging to this branch. Other instances do
     * not belong to this branch.
     */
    int[] availableInstances;

    /**
     * Attributes still available for splitting. Other attributes have been
     * selected on the path from the root.
     */
    int[] availableAttributes;

    /**
     * The attribute selected for splitting at this node; -1 when none has
     * been selected (e.g., a leaf).
     */
    int splitAttribute;

    /**
     * The children nodes; null for a leaf.
     */
    ID3[] children;

    /**
     * My label. Inner nodes also have a label (by majority voting), so that
     * attribute values never seen in this branch's training data can still
     * be classified.
     */
    int label;

    /**
     * The predictions, including queried and predicted labels.
     */
    int[] predicts;

    /**
     * Blocks of at most this size are not split further.
     */
    static int smallBlockThreshold = 3;

    /**
     * ******************************************************
     * The first constructor. Reads the data set from an ARFF file and
     * initializes the root node.
     *
     * @param paraFilename The given file.
     * ******************************************************
     */
    public ID3(String paraFilename) {
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
            fileReader.close();
        } catch (Exception e) {
            System.out.println("Cannot read the file: " + paraFilename + "\r\n" + e);
            System.exit(0);
        } // Of try

        // The last attribute is the class label.
        dataset.setClassIndex(dataset.numAttributes() - 1);
        numClasses = dataset.classAttribute().numValues();

        // Initially every instance and every non-class attribute is available.
        availableInstances = new int[dataset.numInstances()];
        for (int i = 0; i < availableInstances.length; i++) {
            availableInstances[i] = i;
        } // Of for i
        availableAttributes = new int[dataset.numAttributes() - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            availableAttributes[i] = i;
        } // Of for i

        // Initialize.
        children = null;
        splitAttribute = -1;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    } // Of the first constructor

    /**
     * ************************************************************
     * The second constructor, used while building sub-trees. The whole data
     * set is passed in because only the root constructor reads it from disk.
     *
     * @param paraDataset             The given data set.
     * @param paraAvailableInstances  The given available instances.
     * @param paraAvailableAttributes The given available attributes.
     * ************************************************************
     */
    public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
        dataset = paraDataset;
        availableInstances = paraAvailableInstances;
        availableAttributes = paraAvailableAttributes;

        // Initialize.
        children = null;
        splitAttribute = -1;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    } // Of the second constructor

    /**
     * *****************************************************
     * Is the given block pure, i.e., do all of its instances share one label?
     *
     * @param paraBlock The block to judge.
     * @return True if pure (vacuously true for empty or singleton blocks).
     * *****************************************************
     */
    public boolean pureJudge(int[] paraBlock) {
        pure = true;
        for (int i = 1; i < paraBlock.length; i++) {
            if (dataset.instance(paraBlock[i]).classValue()
                    != dataset.instance(paraBlock[0]).classValue()) {
                pure = false;
                break;
            } // Of if
        } // Of for i
        return pure;
    } // Of pureJudge

    /**
     * **********************************************************
     * Compute the majority class of the given block for voting.
     *
     * @param paraBlock The given block.
     * @return The majority class index.
     * **********************************************************
     */
    public int getMajorityClass(int[] paraBlock) {
        int[] tempClassCounts = new int[dataset.numClasses()];
        for (int i = 0; i < paraBlock.length; i++) {
            tempClassCounts[(int) dataset.instance(paraBlock[i]).classValue()]++;
        } // Of for i

        int resultMajorityClass = -1;
        int tempMaxCount = -1;
        for (int i = 0; i < tempClassCounts.length; i++) {
            if (tempMaxCount < tempClassCounts[i]) {
                resultMajorityClass = i;
                tempMaxCount = tempClassCounts[i];
            } // Of if
        } // Of for i
        return resultMajorityClass;
    } // Of getMajorityClass

    /**
     * *******************************************************
     * Select the attribute with the minimal conditional entropy (maximal
     * information gain). Also stores the choice in splitAttribute.
     *
     * @return The best attribute index, or -1 if none is available.
     * *******************************************************
     */
    public int selectBestAttribute() {
        splitAttribute = -1;
        double tempMinimalEntropy = 10000;
        double tempEntropy;
        for (int i = 0; i < availableAttributes.length; i++) {
            tempEntropy = conditionlEntropy(availableAttributes[i]);
            if (tempEntropy < tempMinimalEntropy) {
                tempMinimalEntropy = tempEntropy;
                splitAttribute = availableAttributes[i];
            } // Of if
        } // Of for i
        return splitAttribute;
    } // Of selectBestAttribute

    /**
     * *****************************************************
     * Compute the conditional entropy of the class given an attribute over
     * the available instances. (Name keeps its original spelling so existing
     * callers are unaffected.)
     *
     * @param paraAttribute The given attribute.
     * @return The conditional entropy.
     * *****************************************************
     */
    public double conditionlEntropy(int paraAttribute) {
        // Step 1. Statistics: count instances per attribute value and per
        // (value, class) pair.
        int tempNumClasses = dataset.numClasses();
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int tempNumInstances = availableInstances.length;
        double[] tempValueCounts = new double[tempNumValues];
        double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];

        int tempClass, tempValue;
        for (int i = 0; i < tempNumInstances; i++) {
            tempClass = (int) dataset.instance(availableInstances[i]).classValue();
            tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
            tempValueCounts[tempValue]++;
            tempCountMatrix[tempValue][tempClass]++;
        } // Of for i

        // Step 2. Sum p(value) * H(class | value), skipping empty cells to
        // avoid log(0).
        double resultEntropy = 0;
        double tempEntropy, tempFraction;
        for (int i = 0; i < tempNumValues; i++) {
            if (tempValueCounts[i] == 0) {
                continue; // Skip attribute values with zero count.
            } // Of if
            tempEntropy = 0;
            for (int j = 0; j < tempNumClasses; j++) {
                tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
                if (tempFraction == 0) {
                    continue;
                } // Of if
                tempEntropy += -tempFraction * Math.log(tempFraction);
            } // Of for j
            resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
        } // Of for i
        return resultEntropy;
    } // Of conditionlEntropy

    /**
     * ********************************************************************
     * Split the available instances according to the given attribute.
     *
     * @param paraAttribute The given attribute.
     * @return One block of instance indices per attribute value.
     * ********************************************************************
     */
    public int[][] splitData(int paraAttribute) {
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int[][] resultBlocks = new int[tempNumValues][];
        int[] tempSizes = new int[tempNumValues];

        // First scan to count the size of each block.
        int tempValue;
        for (int i = 0; i < availableInstances.length; i++) {
            tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
            tempSizes[tempValue]++;
        } // Of for i

        // Allocate space.
        for (int i = 0; i < tempNumValues; i++) {
            resultBlocks[i] = new int[tempSizes[i]];
        } // Of for i

        // Second scan to fill in the instance indices.
        Arrays.fill(tempSizes, 0);
        for (int i = 0; i < availableInstances.length; i++) {
            tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
            resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
            tempSizes[tempValue]++;
        } // Of for i
        return resultBlocks;
    } // Of splitData

    /**
     * *******************************************************
     * Build the decision tree recursively. A node stays a leaf when it is
     * pure, when its block is too small, or when no attribute remains.
     * *******************************************************
     */
    public void buildTree() {
        if (pureJudge(availableInstances)) {
            return;
        } // Of if
        if (availableInstances.length <= smallBlockThreshold) {
            return;
        } // Of if
        // Guard: without remaining attributes, selectBestAttribute() would
        // return -1 and splitData(-1) would throw. Keep this node a leaf.
        if (availableAttributes.length == 0) {
            return;
        } // Of if

        selectBestAttribute();
        int[][] tempSubBlocks = splitData(splitAttribute);
        children = new ID3[tempSubBlocks.length];

        // Construct the remaining attribute set: drop splitAttribute; entries
        // below it keep their index, entries above it shift left by one.
        int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            if (availableAttributes[i] < splitAttribute) {
                tempRemainingAttributes[i] = availableAttributes[i];
            } else if (availableAttributes[i] > splitAttribute) {
                tempRemainingAttributes[i - 1] = availableAttributes[i];
            } // Of if
        } // Of for i

        // Construct children.
        for (int i = 0; i < children.length; i++) {
            if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
                children[i] = null;
            } else {
                children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
                // Important code: do this recursively.
                children[i].buildTree();
            } // Of if
        } // Of for i
    } // Of buildTree

    /**
     * **********************************************************
     * Classify an instance by walking down the tree. Falls back to this
     * node's majority label when the matching child is absent.
     *
     * @param paraInstance The given instance.
     * @return The predicted class index.
     * **********************************************************
     */
    public int classify(Instance paraInstance) {
        if (children == null) {
            return label;
        } // Of if

        ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
        // Fixed: the original re-tested "children == null" here, so a null
        // child caused a NullPointerException instead of using the label.
        if (tempChild == null) {
            return label;
        } // Of if
        return tempChild.classify(paraInstance);
    } // Of classify

    /**
     * *************************************************************
     * Compute the accuracy on the given data set.
     *
     * @param paraDataset The given testing data.
     * @return The fraction of correctly classified instances.
     * *************************************************************
     */
    public double test(Instances paraDataset) {
        double tempCorrect = 0;
        for (int i = 0; i < paraDataset.numInstances(); i++) {
            if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i
        return tempCorrect / paraDataset.numInstances();
    } // Of test

    /**
     * *******************************************************
     * Test on the training set.
     *
     * @return The training-set accuracy.
     * *******************************************************
     */
    public double selfTest() {
        return test(dataset);
    } // Of selfTest

    /**
     * **********************************************
     * Overrides the method claimed in Object.
     *
     * @return The tree structure as text.
     * **********************************************
     */
    public String toString() {
        String resultString = "";
        if (children == null) {
            // Leaf node. (The attribute name is looked up only for inner
            // nodes; a leaf's splitAttribute may be -1.)
            resultString += "class = " + label;
        } else {
            String tempAttributeName = dataset.attribute(splitAttribute).name();
            for (int i = 0; i < children.length; i++) {
                if (children[i] == null) {
                    resultString += tempAttributeName + " = "
                            + dataset.attribute(splitAttribute).value(i) + ":" + "class = " + label
                            + "\r\n";
                } else {
                    resultString += tempAttributeName + " = "
                            + dataset.attribute(splitAttribute).value(i) + ":" + children[i]
                            + "\r\n";
                } // Of if
            } // Of for i
        } // Of if
        return resultString;
    } // Of toString

    /**
     * *****************************************************
     * Unit test: build a tree from a local ARFF file and report the
     * training-set accuracy.
     * ****************************************************
     */
    public static void id3Test() {
        ID3 tempID3 = new ID3("E:/Datasets/UCIdatasets/chess.arff");
        ID3.smallBlockThreshold = 3;
        tempID3.buildTree();
        System.out.println("The tree is: \r\n" + tempID3);
        double tempAccuracy = tempID3.selfTest();
        System.out.println("The accuracy is: " + tempAccuracy);
    } // Of id3Test

    /**
     * **************************************************
     * The entrance of the program.
     *
     * @param args Not used now.
     * **************************************************
     */
    public static void main(String args[]) {
        id3Test();
    } // Of main
} // Of class ID3
// Notes: The first few methods are fairly simple and easy to follow, including
// entropy computation and majority-vote labeling. The hard part is building the
// tree. At first it was unclear why the second constructor takes the whole data
// set as its first parameter: the data set is read from disk only in the first
// constructor, so sub-tree nodes must receive the full data set to keep working.
// The accuracy looks suspiciously high -- around 0.99 on three or four UCI
// public data sets -- which may indicate overfitting, since the evaluation is
// performed on the training set itself.