数据挖掘笔记-分类-Adaboost-原理与简单实现

Adaboost(Adaptive Boosting)是一种原理简单,但很实用的有监督机器学习算法。说Boosting算法之前,先说下Bagging算法,它们两个都是把一些弱分类器组合起来进行分类的方法,统称为集成方法(Ensemble Method),虽然每个弱分类器分类的结果不那么准确,但是如果把多个弱分类器组合起来却可以得到相当不错的结果。同时集成方法还可以组合不同的分类器。Bagging算法大致思路是:从整体样本集合中,抽样n<N个样本,针对抽样的集合训练分类器,然后分类器进行投票,最终的结果是分类器投票的优胜结果。代表例子就是随机森林(Random Forest)。随机森林的每个弱分类器是决策树(Decision Tree),输出的类别由多个决策树分类的类别投票决定。但是,这种方法只是将分类器进行简单的组合,实际上并没有发挥出分类器组合的威力来。直到后来Freund 和 Schapire提出了Adaboost这一简单而有效的算法。Adaboost算法不同于Bagging算法的地方是:AdaBoost算法的每个弱分类器组合起来的权重不一样,而Bagging的每个弱分类器的组合权重是相等的。Adaboost原理如下图:

数据挖掘笔记-分类-Adaboost-原理与简单实现_第1张图片

左边矩形表示数据集,中间方形表示弱分类器,后面的三角形表示每个分类器对应的权重,最后根据每个弱分类器的加权组合来判断总体类别。这是Adaboost算法的整个原理过程。在这个过程中需要注意如下:
1、每次迭代改变的是样本的分布,或者说是样本数据的权重,而不是重复采样。
2、数据集中每个样本数据对应有一个权重,初始化的时候每条样本数据的权重是一样的,一般为(1 / 样本数量)。后面在每次弱分类器训练完数据后会更新一次样本数据的权重。如何更新会在下面说明。
3、弱分类器的权重alpha是根据弱分类器的错误率来调整,计算如下图
数据挖掘笔记-分类-Adaboost-原理与简单实现_第2张图片

弱分类器权重计算完成后可以更新样本数据的权重,也就是说样本数据的权重间接取决于分类错误率:样本分布的改变取决于样本是否被正确分类,分类正确的样本权值降低,分类错误的样本权值提高,使后续的弱分类器更关注被分错的样本。如下图所示:
正确预测分类:

错误预测分类:

Adaboost算法优点:
1)Adaboost是一种有很高精度的分类器。
2)可以使用各种方法构建子分类器,Adaboost算法提供的仅仅是框架。
3)当使用简单分类器时,计算出的结果是可以理解的。而且弱分类器构造极其简单。
4)不用做特征筛选。
5)在弱分类器足够简单、迭代轮数适中时,通常不容易过拟合(但在噪声较多的数据上仍有可能过拟合,并非完全不用担心)。


下面是用Java简单实现Adaboost算法的一个例子:
/**
 * Adaboost分类器 
 */
/**
 * Adaboost (Adaptive Boosting) classifier.
 *
 * <p>Trains {@link #CLASSIFIER_NUM} weak classifiers sequentially, re-weighting
 * the training samples after each round, and classifies new data by a weighted
 * combination of the weak classifiers' outputs.
 */
public class AdaboostClassifier {

	/** Number of boosting rounds, i.e. weak classifiers to train. */
	public static final int CLASSIFIER_NUM = 40;
	
	/**
	 * Trains the ensemble on the given data set and returns the per-round
	 * training results (feature weights, classifier weight, per-sample output).
	 *
	 * @param datas training samples; each array holds the feature values
	 * @return one {@code TrainResult} per boosting round
	 */
	public static List<TrainResult> train(List<double[]> datas) {
		// Initially every sample carries the same weight 1/N.
		List<Double> dataWeights = new ArrayList<Double>();
		for (int i = 0, len = datas.size(); i < len; i++) {
			dataWeights.add(1.0 / len);
		}
		List<TrainResult> trainResults = new ArrayList<TrainResult>();
		for (int n = 0; n < CLASSIFIER_NUM; n++) {
			TrainResult trainResult = WeakClassifier.train(datas, dataWeights);
			trainResults.add(trainResult);
			List<Double> classifyResult = trainResult.getClassifyResult();
			double sum = 0, alpha = trainResult.getClassifierWeight();
			// Re-weight samples: w_i *= exp(-alpha * h(x_i)), then normalize.
			// NOTE(review): the textbook AdaBoost update is exp(-alpha * y_i * h(x_i));
			// only the prediction h(x_i) is used here because the true label is not
			// available in this overload — confirm against the data layout.
			for (int i = 0, len = datas.size(); i < len; i++) {
				double newWeight = dataWeights.get(i) * Math.exp(-alpha * classifyResult.get(i));
				dataWeights.set(i, newWeight);
				sum += newWeight;
			}
			// Normalize so the sample weights again sum to 1.
			for (int i = 0, len = datas.size(); i < len; i++) {
				dataWeights.set(i, dataWeights.get(i) / sum);
			}
		}
		return trainResults;
	}
	
	/**
	 * Trains on data with separate category labels.
	 *
	 * @param datas      training samples
	 * @param categories class label for each sample
	 * @return never returns normally
	 * @throws UnsupportedOperationException always — not implemented yet.
	 *         (Previously this silently returned {@code null}, which hid the
	 *         missing implementation from callers until a later NPE.)
	 */
	public static List<TrainResult> train(List<double[]> datas, List<Double> categories) {
		throw new UnsupportedOperationException(
				"train(datas, categories) is not implemented yet");
	}
	
	/**
	 * Classifies a single sample by the weighted sum of all weak classifiers.
	 *
	 * @param data         feature values of the sample
	 * @param trainResults training results produced by {@link #train(List)}
	 * @return {@code 1} if the weighted sum is positive, {@code -1} otherwise
	 */
	public static double classify(double[] data, List<TrainResult> trainResults) {
		double result = 0;
		for (TrainResult trainResult : trainResults) {
			double classifierWeight = trainResult.getClassifierWeight();
			List<Double> featureWeight = trainResult.getFeatureWeight();
			for (int i = 0, len = data.length; i < len; i++) {
				result += data[i] * classifierWeight * featureWeight.get(i);
			}
			// The entry after the last feature acts as the bias term.
			result += featureWeight.get(data.length);
		}
		return result > 0 ? 1 : -1;
	}
	
	/**
	 * Classifies every sample of a data set.
	 *
	 * @param datas        samples to classify
	 * @param trainResults training results produced by {@link #train(List)}
	 * @return predicted class (+1/-1) for each sample, in input order
	 */
	public static List<Double> classify(List<double[]> datas, List<TrainResult> trainResults) {
		List<Double> results = new ArrayList<Double>();
		for (int i = 0, len = datas.size(); i < len; i++) {
			results.add(classify(datas.get(i), trainResults));
		}
		return results;
	}
	
}

/**
 * 弱分类器
 */
/**
 * Weak classifier: a linear classifier whose feature weights are fitted
 * with a sample-weighted perceptron.
 */
public class WeakClassifier {
	
	/** Number of perceptron training passes over the data set. */
	public static final int NUMBER = 100;
	
	/**
	 * Trains one weak classifier on the (weighted) data set.
	 *
	 * @param datas       training samples
	 * @param dataWeights current Adaboost weight of each sample
	 * @return the training result: feature weights, per-sample prediction
	 *         (+1/-1) and the classifier weight alpha
	 */
	public static TrainResult train(List<double[]> datas, List<Double> dataWeights) {
		TrainResult trainResult = new TrainResult();
		List<Double> featureWeight = perceptron(datas, dataWeights);
		trainResult.setFeatureWeight(featureWeight);
		int dimension = datas.get(0).length;
		int right = 0, error = 0;
		for (int i = 0, len = datas.size(); i < len; i++) {
			double result = 0;
			// NOTE(review): the last column is skipped here and the last weight is
			// used as a bias, while perceptron() trains over ALL columns — confirm
			// whether the last column is meant to be the label.
			for (int j = 0; j < dimension - 1; j++) {
				result += datas.get(i)[j] * featureWeight.get(j);
			}
			result += featureWeight.get(dimension - 1);
			if (result >= 0) {
				right++;
				trainResult.getClassifyResult().add(1d);
			} else {
				error++;
				trainResult.getClassifyResult().add(-1d);
			}
		}
		System.out.println("total: " + datas.size() + " right: " + right);
		// (double) cast instead of the deprecated new Double(...) boxing.
		double errorRate = (double) error / datas.size();
		// Clamp away from 0 and 1 so alpha stays finite:
		// log((1 - e) / e) diverges at either extreme.
		errorRate = Math.min(Math.max(errorRate, 1e-10), 1 - 1e-10);
		// NOTE(review): textbook AdaBoost uses alpha = 0.5 * log((1-e)/e);
		// the 0.5 factor is omitted here — verify intended scaling.
		double alpha = Math.log((1 - errorRate) / errorRate);
		trainResult.setClassifierWeight(alpha);
		return trainResult;
	}
	
	/**
	 * Perceptron algorithm: iteratively computes the feature weights,
	 * scaling each update by the sample's Adaboost weight.
	 *
	 * @param datas       training samples
	 * @param dataWeights current Adaboost weight of each sample
	 * @return one weight per feature dimension
	 */
	public static List<Double> perceptron(List<double[]> datas, List<Double> dataWeights) {
		int dimension = datas.get(0).length;
		List<Double> featureWeights = new ArrayList<Double>();
		for (int i = 0; i < dimension; i++) {
			featureWeights.add(1d);
		}
		for (int n = 0; n < NUMBER; n++) {
			for (int i = 0, len = datas.size(); i < len; i++) {
				double result = 0;
				for (int j = 0; j < dimension; j++) {
					result += datas.get(i)[j] * featureWeights.get(j);
				}
				// Misclassified (non-positive score): nudge the weights toward
				// the sample, scaled by its Adaboost weight.
				if (result <= 0) {
					for (int k = 0; k < dimension; k++) {
						double newWeight = featureWeights.get(k) + datas.get(i)[k] * dataWeights.get(i);
						featureWeights.set(k, newWeight);
					}
				}
			}
		}
		return featureWeights;
	}
	
}

/**
 * Builds the training/test data sets from text files on the classpath and
 * drives a full Adaboost train/classify/evaluate cycle.
 */
public class AdaboostBuilder {
	
	/** Loads the bundled training set from {@code trainset/adaboost1.txt}. */
	public static DataSet initTrainSet() {
		DataSet dataSet = null;
		try {
			String path = AdaboostBuilder.class.getClassLoader().getResource("trainset/adaboost1.txt").toURI().getPath();
			dataSet = initTrainSet(path);
		} catch (URISyntaxException e) {
			e.printStackTrace();
		}
		return dataSet;
	}
	
	/**
	 * Reads a training set: each line has two feature values followed by the
	 * class label; the label is stored both as the third element of the data
	 * array and in the category list.
	 */
	public static DataSet initTrainSet(String path) {
		return readDataSet(path, true);
	}
	
	/** Loads the bundled test set from {@code testset/adaboost1.txt}. */
	public static DataSet initTestSet() {
		DataSet dataSet = null;
		try {
			String path = AdaboostBuilder.class.getClassLoader().getResource("testset/adaboost1.txt").toURI().getPath();
			dataSet = initTestSet(path);
		} catch (URISyntaxException e) {
			e.printStackTrace();
		}
		return dataSet;
	}
	
	/**
	 * Reads a test set: two feature values per data array, with the label
	 * kept only in the category list.
	 */
	public static DataSet initTestSet(String path) {
		return readDataSet(path, false);
	}
	
	/**
	 * Shared whitespace-separated file reader for both set layouts.
	 *
	 * @param path        file to read
	 * @param labelInData when {@code true} the label is also appended as the
	 *                    third element of each data array (training layout)
	 * @return the populated data set; partially filled on read/parse errors
	 */
	private static DataSet readDataSet(String path, boolean labelInData) {
		DataSet dataSet = new DataSet();
		// try-with-resources closes the reader (and underlying stream) reliably.
		try (BufferedReader reader = new BufferedReader(new InputStreamReader(
				new FileInputStream(new File(path)), java.nio.charset.StandardCharsets.UTF_8))) {
			String line = reader.readLine();
			// Stop at EOF or at the first empty line (original behavior).
			while (line != null && !line.isEmpty()) {
				StringTokenizer tokenizer = new StringTokenizer(line);
				double[] data = new double[labelInData ? 3 : 2];
				data[0] = Double.parseDouble(tokenizer.nextToken());
				data[1] = Double.parseDouble(tokenizer.nextToken());
				double category = Double.parseDouble(tokenizer.nextToken());
				if (labelInData) {
					data[2] = category;
				}
				dataSet.getDatas().add(data);
				dataSet.getCategories().add(category);
				line = reader.readLine();
			}
		} catch (Exception e) {
			// Best-effort load: keep the original print-and-continue behavior so a
			// missing or malformed file yields a (possibly empty) data set.
			e.printStackTrace();
		}
		return dataSet;
	}
	
	/** Trains on the bundled training set, classifies the test set, prints stats. */
	public void run() {
		DataSet trainSet = initTrainSet();
		List<TrainResult> trainResults = AdaboostClassifier.train(trainSet.getDatas());
		DataSet testSet = initTestSet();
		List<Double> predictedValues = AdaboostClassifier.classify(testSet.getDatas(), trainResults);
		DataResult dataResult = new DataResult();
		dataResult.setDatas(testSet.getDatas());
		dataResult.setActualValues(testSet.getCategories());
		dataResult.setPredictedValues(predictedValues);
		dataResult.statistics();
	}
	
	public static void main(String[] args) {
		new AdaboostBuilder().run();
	}
	
}


代码托管:https://github.com/fighting-one-piece/repository-datamining.git



你可能感兴趣的:(数据挖掘,分类,adaboost,boosting,bagging)