300 Lines of Code a Day: Days 61-62 (the Decision Tree Algorithm)

The code comes from Prof. Min Fan's series "日撸 Java 三百行" (Days 61-70):

日撸 Java 三百行(61-70天,决策树与集成学习)_闵帆的博客-CSDN博客

package machinelearning.decisiontree;

import java.io.FileReader;
import java.util.Arrays;

import weka.core.Instance;
import weka.core.Instances;

/**
 * The ID3 decision tree inductive algorithm.
 * 
 * @author WX873
 *
 */
public class ID3 {
	/**
	 * The data
	 */
	Instances dataset;
	
	/**
	 * Is this data set pure (only one label)?
	 */
	boolean pure;
	
	/**
	 * The number of classes. For binary classification it is 2.
	 */
	int numClasses;
	
	/**
	 * Available instances. Other instances do not belong to this branch.
	 */
	int[] availableInstances;
	
	/**
	 * Available attributes. Other attributes have been selected in the path from the root.
	 */
	int[] availableAttributes;
	
	/**
	 * The selected attribute.
	 */
	int splitAttribute;
	
	/**
	 * The children nodes.
	 */
	ID3[] children;
	
	/**
	 * My label. Inner nodes also have a label. For example, a branch such as
	 * (outlook = sunny, humidity = high) may never appear in the training
	 * data, yet the inner node's majority label is still valid for queries
	 * that reach it.
	 */
	int label;
	
	/**
	 * The prediction, including queried and predicted labels.
	 */
	int[] predicts;
	
	/**
	 * Small blocks cannot be split further.
	 */
	static int smallBlockThreshold = 3;
	
	/**
	 * ******************************************************
	 * The first constructor.
	 * 
	 * @param paraFilename  The given file.
	 * ******************************************************
	 */
	public ID3(String paraFilename) {
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			fileReader.close();
		} catch (Exception e) {
			System.out.println("Cannot read the file: " + paraFilename + "\r\n" + e);
			System.exit(0);
		}//of try
		
		dataset.setClassIndex(dataset.numAttributes() - 1);
		numClasses = dataset.classAttribute().numValues();
		
		availableInstances = new int[dataset.numInstances()];
		for (int i = 0; i < availableInstances.length; i++) {
			availableInstances[i] = i;
		}//of for i
		
		availableAttributes = new int[dataset.numAttributes() - 1];
		for (int i = 0; i < availableAttributes.length; i++) {
			availableAttributes[i] = i;
		}//of for i
		
		//Initialize.
		children = null;
		
		//Determine the label by simple voting.
		label = getMajorityClass(availableInstances);
		
		//Determine whether or not it is pure.
		pure = pureJudge(availableInstances);
		
	}//Of the first constructor
	
	/**
	 * ************************************************************
	 * The constructor.
	 * 
	 * @param paraDataset   The given data set.
	 * @param paraAvailableInstances   The given available instances.
	 * @param paraAvailableAttributes  The given available attributes.
	 * ************************************************************
	 */
	public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
		dataset = paraDataset;
		availableInstances = paraAvailableInstances;
		availableAttributes = paraAvailableAttributes;
		
		//Initialize.
		children = null;
		
		//Determine the label by simple voting.
		label = getMajorityClass(availableInstances);
		
		//Determine whether or not it is pure.
		pure = pureJudge(availableInstances);
	}//Of the second constructor
	
	/**
	 * *****************************************************
	 * Judge whether the given block is pure.
	 * @param paraBlock  The block to be judged.
	 * @return  True if pure.
	 * *****************************************************
	 */
	public boolean pureJudge(int[] paraBlock) {
		pure = true;
		
		for (int i = 1; i < paraBlock.length; i++) {
			if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) {
				pure = false;
				break;
			}//of if
		}//of for i
		return pure;
	}//of pureJudge
	
	/**
	 * **********************************************************
	 * Compute the majority class of the given block for voting.
	 * 
	 * @param paraBlock   The given block.
	 * @return      The majority class.
	 * **********************************************************
	 */
	public int getMajorityClass(int[] paraBlock) {
		int[] tempClassCounts = new int[dataset.numClasses()];
		for (int i = 0; i < paraBlock.length; i++) {
			tempClassCounts[(int) dataset.instance(paraBlock[i]).classValue()]++;
		}//of for i
		
		int resultMajorityClass = -1;
		int tempMaxCount = -1;
		for (int i = 0; i < tempClassCounts.length; i++) {
			if (tempMaxCount < tempClassCounts[i]) {
				resultMajorityClass = i;
				tempMaxCount = tempClassCounts[i];
			}//of if
		}//of for i
		
		return resultMajorityClass;
	}//of getMajorityClass
	
	/**
	 * *******************************************************
	 * Select the best attribute.
	 * @return  The best attribute index.
	 * *******************************************************
	 */
	public int selectBestAttribute() {
		splitAttribute = -1;
		double tempMinimalEntropy = Double.MAX_VALUE;
		double tempEntropy;
		for (int i = 0; i < availableAttributes.length; i++) {
			tempEntropy = conditionalEntropy(availableAttributes[i]);
			if (tempEntropy < tempMinimalEntropy) {
				tempMinimalEntropy = tempEntropy;
				splitAttribute = availableAttributes[i];
			}//of if
		}//of for i
		
		return splitAttribute;
	}//of selectBestAttribute
	
	/**
	 * *****************************************************
	 * Compute the conditional entropy of an attribute.
	 * 
	 * @param paraAttribute  The given attribute.
	 * @return   The entropy.
	 * *****************************************************
	 */
	public double conditionalEntropy(int paraAttribute) {
		//Step 1. Statistics.
		int tempNumClasses = dataset.numClasses();
		int tempNumValues = dataset.attribute(paraAttribute).numValues();
		int tempNumInstances = availableInstances.length;
		double[] tempValueCounts = new double[tempNumValues];
		double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];
		
		int tempClass, tempValue;
		for (int i = 0; i < tempNumInstances; i++) {
			tempClass = (int)dataset.instance(availableInstances[i]).classValue();
			tempValue = (int)dataset.instance(availableInstances[i]).value(paraAttribute);
			tempValueCounts[tempValue]++;
			tempCountMatrix[tempValue][tempClass]++;
		}//of for i
		
		//Step 2. Compute Entropy.
		double resultEntropy = 0;
		double tempEntropy, tempFraction;
		
		for (int i = 0; i < tempNumValues; i++) {
			if (tempValueCounts[i] == 0) {
				continue;  //Skip values with zero count.
			}//of if
			tempEntropy = 0;
			for (int j = 0; j < tempNumClasses; j++) {
				tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
				if (tempFraction == 0) {
					continue;
				}//of if
				tempEntropy += -tempFraction * Math.log(tempFraction);
			}//of for j
			resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
		}//of for i
		
		return resultEntropy;
	}//of conditionalEntropy
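	
	/*
	 * A worked example (illustrative, not from the original post), using
	 * Quinlan's classic 14-instance weather data: attribute outlook takes
	 * values {sunny, overcast, rainy} with class counts (2 yes, 3 no),
	 * (4 yes, 0 no) and (3 yes, 2 no). Since Math.log is the natural
	 * logarithm, entropies are in nats rather than bits, but the minimizing
	 * attribute is the same either way:
	 *   H(D|outlook) = 5/14 * H(2/5, 3/5) + 4/14 * 0 + 5/14 * H(3/5, 2/5)
	 *                ≈ 5/14 * 0.673 + 0 + 5/14 * 0.673 ≈ 0.481 nats.
	 */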
	
	/**
	 * ********************************************************************
	 * Split the data according to the given attribute.
	 * @param paraAttribute   The given attribute.
	 * @return  The blocks.
	 * ********************************************************************
	 */
	public int[][] splitData(int paraAttribute) {
		int tempNumValues = dataset.attribute(paraAttribute).numValues();
		int[][] resultBlocks = new int[tempNumValues][]; 
		int[] tempSizes = new int[tempNumValues];
		
		//First scan to count the size of each block.
		int tempValue;
		for (int i = 0; i < availableInstances.length; i++) {
			tempValue = (int)dataset.instance(availableInstances[i]).value(paraAttribute);
			tempSizes[tempValue]++;
		}//of for i
		
		//Allocate space.
		for (int i = 0; i < tempNumValues; i++) {
			resultBlocks[i] = new int[tempSizes[i]];
		}//of for i
		
		// Second scan to fill.
		Arrays.fill(tempSizes, 0);
		for (int i = 0; i < availableInstances.length; i++) {
			tempValue = (int)dataset.instance(availableInstances[i]).value(paraAttribute);
			resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
			tempSizes[tempValue]++;
		}//of for i
		
		return resultBlocks;
	}//of splitData
	
	
	/**
	 * ********************************************************************
	 * Build the decision tree recursively.
	 * ********************************************************************
	 */
	public void buildTree() {
		if (pureJudge(availableInstances)) {
			return;
		}//of if
		if (availableInstances.length <= smallBlockThreshold) {
			return;
		}//of if
		
		//Defensive guard: no attribute is left for splitting.
		if (availableAttributes.length == 0) {
			return;
		}//of if
		
		selectBestAttribute();
		int[][] tempSubBlocks = splitData(splitAttribute);
		children = new ID3[tempSubBlocks.length];
		
		//Construct the remaining attribute set.
		int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
		for (int i = 0; i < availableAttributes.length; i++) {
			//Skip the attribute equal to splitAttribute; smaller ones keep
			//their positions and larger ones shift forward by one.
			if (availableAttributes[i] < splitAttribute) {
				tempRemainingAttributes[i] = availableAttributes[i];
			}else if(availableAttributes[i] > splitAttribute){
				tempRemainingAttributes[i - 1] = availableAttributes[i];
			}//of if
		}//of for i
		
		//Construct children.
		for (int i = 0; i < children.length; i++) {
			if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
				children[i] = null;
				continue;
			} else {
				children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
				
				// Important code: do this recursively
				children[i].buildTree();
			}//of if
		}//of for i
		
	}//of buildTree
	
	/**
	 * **********************************************************
	 * Classify an instance.
	 * @param paraInstance   The given instance.
	 * @return  The prediction.
	 * **********************************************************
	 */
	public int classify(Instance paraInstance) {
		if (children == null) {
			return label;
		}//of if
		
		ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
		//Bug fix: check the child, not children. A null child means this
		//attribute value had no training instances, so fall back to my label.
		if (tempChild == null) {
			return label;
		}//of if
		
		return tempChild.classify(paraInstance); 
	}//of classify
	
	/**
	 * *************************************************************
	 * Compute the accuracy.
	 * @param paraDataset  The given testing data.
	 * @return  The accuracy.
	 * *************************************************************
	 */
	public double test(Instances paraDataset) {
		double tempCorrect = 0;
		for (int i = 0; i < paraDataset.numInstances(); i++) {
			if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
				tempCorrect++;
			}//of if
		}//of for i
		
		return tempCorrect / paraDataset.numInstances();
	}//of test
	
	/**
	 * *******************************************************
	 * Test on the training set.
	 * @return  The accuracy.
	 * *******************************************************
	 */
	public double selfTest() {
		return test(dataset);
	}//of selfTest
	
	/**
	 * **********************************************
	 * Overrides the method claimed in Object.
	 * 
	 * @return The tree structure.
	 * **********************************************
	 */
	public String toString() {
		String resultString = "";
		if (children == null) {
			resultString += "class = " + label;
		} else {
			//Look up the attribute name only for inner nodes; a leaf may have
			//no split attribute.
			String tempAttributeName = dataset.attribute(splitAttribute).name();
			for (int i = 0; i < children.length; i++) {
				if (children[i] == null) {
					resultString += tempAttributeName + " = " + dataset.attribute(splitAttribute).value(i) + ":" + "class = " + label + "\r\n";
				} else {
					resultString += tempAttributeName + " = " + dataset.attribute(splitAttribute).value(i) + ":" + children[i] + "\r\n";
				}//of if
			}//of for i
		}//of if
		
		return resultString;
	}//of toString
	
	/**
	 * *****************************************************
	 * Test this class.
	 * ****************************************************
	 */
	public static void id3Test() {
		ID3 tempID3 = new ID3("E:/Datasets/UCIdatasets/chess.arff");
		ID3.smallBlockThreshold = 3;
		tempID3.buildTree();
		
		System.out.println("The tree is: \r\n" + tempID3);
		
		double tempAccuracy = tempID3.selfTest();
		System.out.println("The accuracy is: " + tempAccuracy);
	}//of id3Test
	
	/**
	 * **************************************************
	 * The entrance of the program.
	 * 
	 * @param args  Not used now.
	 * **************************************************
	 */
	public static void main(String args[]) {
		id3Test();
	}//of main
	
}//of class ID3

The first few methods are fairly simple and easy to understand, including computing entropy values and determining labels by simple voting. The harder part to grasp is the construction of the tree. At first I did not understand why the first parameter of the second constructor is the entire data set. The reason is that only the first constructor reads the data set from file; every node created afterwards shares that same Instances object, and a branch is described solely by its availableInstances and availableAttributes index arrays, so the second constructor must receive the whole data set to work with. A small sketch of this sharing follows.
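
To make that concrete, here is a minimal sketch (the file name weather.arff is a placeholder, not from the original post). Every child that buildTree() creates is constructed with the root's own dataset reference, never a copy; only the index arrays shrink:

ID3 tempRoot = new ID3("weather.arff");   //The only place the file is read.
tempRoot.buildTree();
//Inside buildTree(), each child is created as
//  new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
//so tempRoot.children[i].dataset is the same object as tempRoot.dataset.
System.out.println(tempRoot);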

The accuracy is surprisingly high: on the three or four UCI public data sets I tried, it reaches 0.99. That is a little alarming, and I am not sure whether it is overfitting. One thing to keep in mind is that selfTest() evaluates on the training set itself (it calls test(dataset)), where a deep unpruned tree is expected to score almost perfectly; accuracy on held-out data would be a fairer measure, as in the sketch below.
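
A minimal sketch of such a hold-out check, assuming weka's trainCV/testCV split helpers (the file path is the same placeholder used in id3Test). It trains on 80% of the data through the second constructor and measures accuracy on the remaining 20%:

import java.io.FileReader;
import weka.core.Instances;

public class ID3HoldOutDemo {
	public static void main(String[] args) throws Exception {
		//Load and shuffle the data.
		FileReader tempReader = new FileReader("E:/Datasets/UCIdatasets/chess.arff");
		Instances tempData = new Instances(tempReader);
		tempReader.close();
		tempData.setClassIndex(tempData.numAttributes() - 1);
		tempData.randomize(new java.util.Random(42));

		//The 5-fold helpers give an 80/20 train/test division.
		Instances tempTrain = tempData.trainCV(5, 0);
		Instances tempTest = tempData.testCV(5, 0);

		//Build the index arrays required by the second constructor.
		int[] tempInstances = new int[tempTrain.numInstances()];
		for (int i = 0; i < tempInstances.length; i++) {
			tempInstances[i] = i;
		}//of for i
		int[] tempAttributes = new int[tempTrain.numAttributes() - 1];
		for (int i = 0; i < tempAttributes.length; i++) {
			tempAttributes[i] = i;
		}//of for i

		ID3 tempTree = new ID3(tempTrain, tempInstances, tempAttributes);
		tempTree.buildTree();
		System.out.println("Hold-out accuracy: " + tempTree.test(tempTest));
	}//of main
}//of class ID3HoldOutDemo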
