【机器学习】决策树

文章目录

    • 一. 决策树算法简介
    • 二. 决策树构建步骤
    • 三. 特征说明
      • 3.1 信息增益(Information Gain, IG)
      • 3.2 基尼不纯度(Gini Impurity)
    • 四. 剪枝策略
    • 五. 决策树的评估
    • 六. 代码实践
      • 例1:决策树分类
      • 例2:决策树回归
      • 补充:可视化
      • 例3:剪枝策略

一. 决策树算法简介

决策树(Decision Tree)用于分类和回归任务。它通过构建树状模型来进行决策。

决策树算法的基本思想是基于数据特征进行递归分裂。每个内部节点代表一个特征上的测试,每个分支代表测试的结果,而每个叶节点代表一个类别(在分类树中)或一个具体的数值(在回归树中)。

决策树算法的优点包括:模型易于理解和解释,不需要对数据进行太多预处理,可以处理数值型和类别型数据。

不过,决策树也有限制,比如容易过拟合,尤其是当树变得很复杂时。为了解决这个问题,通常需要剪枝或者通过集成方法(如随机森林)来提高模型的泛化能力。

二. 决策树构建步骤

  1. 选择最佳分裂特征:通过计算每个特征的信息增益(IG)基尼不纯度(Gini Impurity),选择最能提高决策树准确度的特征。
  2. 决定分裂点:根据选定的特征确定最佳分裂点。
  3. 构建节点:创建决策树的节点,将数据集分裂为子集。
  4. 递归构建树:对每个子集重复步骤1到3,直到满足停止条件(如达到最大深度、节点中的样本数小于最小分裂样本数等)。

三. 特征说明

3.1 信息增益(Information Gain, IG)

信息增益是决策树算法中用于选择数据集分裂特征的一种方法。它基于熵(entropy),熵是度量数据集纯净度的一个指标,用于度量数据集中的混乱程度。信息增益则是指在知道某特征之后带来的熵的减少(即纯净度的增加)。

在构建决策树时,算法会计算每个特征的信息增益,并选择信息增益最大的特征作为节点分裂的特征。具体来说,信息增益定义为父节点的熵与加权平均后子节点熵的差值。

给定一个数据集(D),其熵可以表示为:

E n t r o p y ( D ) = − ∑ k = 1 K p k log ⁡ 2 ( p k ) Entropy(D) = -\sum_{k=1}^{K} p_k \log_2(p_k) Entropy(D)=k=1Kpklog2(pk)

其中,(K)是类别的数量,(p_k)是属于第(k)类的样本所占的比例。信息增益则计算为:

I G ( D , A ) = E n t r o p y ( D ) − ∑ v ∈ V a l u e s ( A ) ∣ D v ∣ ∣ D ∣ E n t r o p y ( D v ) IG(D, A) = Entropy(D) - \sum_{v \in Values(A)} \frac{|D_v|}{|D|} Entropy(D_v) IG(D,A)=Entropy(D)vValues(A)DDvEntropy(Dv)

这里,(A)表示一个特征,(Values(A))是该特征所有可能的值,(D_v)是在特征(A)的值为(v)时的子集,而(|D_v|/|D|)则是权重。

3.2 基尼不纯度(Gini Impurity)

基尼不纯度是另一种度量数据集纯净度的方法,常用于CART(Classification and Regression Trees)算法。与熵类似,基尼不纯度衡量的是一个随机选中的样本被错误标记的概率。

基尼不纯度越小,数据集的纯净度越高。对于给定的数据集(D),其基尼不纯度可以表示为:

G i n i ( D ) = 1 − ∑ k = 1 K p k 2 Gini(D) = 1 - \sum_{k=1}^{K} p_k^2 Gini(D)=1k=1Kpk2

同样地,在构建决策树时,算法会寻找可以最大程度降低基尼不纯度的特征。

在选择最佳分裂特征时,不同的算法可能会偏好信息增益或基尼不纯度。信息增益倾向于选择拥有更多值的特征,而基尼不纯度则是一种更加平衡的选择标准,因为它不倾向于选择那些拥有大量唯一值的特征(如ID或唯一标识符)。

通过使用信息增益或基尼不纯度,决策树算法可以有效地选择出最合适的特征进行分裂,以构建出既能反映数据特征又具有较好泛化能力的模型。

关于更多基尼不纯度的介绍,也可以参考博客园的这篇基尼不纯度简介。

四. 剪枝策略

为了避免过拟合,决策树算法中引入了剪枝策略,包括预剪枝和后剪枝。预剪枝通过提前停止树的构建来防止过拟合;后剪枝则是先构建树,然后删除那些提升模型泛化能力的子树。

Scikit-learn的决策树模型提供了多个参数来控制树的生长,以此来实现预剪枝,但它不直接支持后剪枝操作。

五. 决策树的评估

决策树模型的评估可以通过多种方式进行,如准确率、召回率、F1分数、以及通过交叉验证来评估模型的泛化能力。

六. 代码实践

例1:决策树分类

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# 加载iris数据集
iris = load_iris()
X = iris.data
y = iris.target

# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建决策树分类器实例
dtree = DecisionTreeClassifier(max_depth=3, random_state=42)

# 在训练集上训练决策树分类器
dtree.fit(X_train, y_train)

# 对测试集进行预测
predictions = dtree.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, predictions)
print(f'决策树分类器的准确率: {accuracy}')

例2:决策树回归

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
"""
使用加利福尼亚州房价数据集来代替经典的Boston房价数据集.
load_boston函数已经从scikit-learn 1.2版本开始被移除了,原因是Boston房价数据集存在道德问题。
该数据集的创建者在变量"B"的设计中假设种族自我隔离对房价有正面影响,这一假设及其研究目的受到了质疑。
"""
# 加载加利福尼亚州房价数据集
housing = fetch_california_housing()
X = housing.data
y = housing.target

# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建决策树回归器实例
dtree_reg = DecisionTreeRegressor(max_depth=3, random_state=42)

# 在训练集上训练决策树回归器
dtree_reg.fit(X_train, y_train)

# 对测试集进行预测
predictions = dtree_reg.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y_test, predictions)
print(f'决策树回归器的均方误差: {mse}')

通过调整决策树的参数(如max_depth)和应用剪枝策略,可以有效地控制模型的复杂度,从而改善模型的性能和泛化能力。

补充:可视化

当然,我们还可以进一步对结果进行可视化:

  • 通过matplotlib绘制实际值和预测价格的偏差;
  • 通过plot_tree进行决策树的可视化、
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
# 可视化预测结果与实际值
plt.figure(figsize=(10, 6))
plt.scatter(y_test, predictions, alpha=0.5, label='Predictions')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], 'k--', lw=2, label='Perfect Prediction')
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('Decision Tree Regressor Predictions vs Actual')
plt.legend()
plt.show()

# 使用plot_tree进行决策树可视化
plt.figure(figsize=(20,10))
plot_tree(dtree_reg, 
          feature_names=housing.feature_names, 
          filled=True, 
          rounded=True, 
          max_depth=3) # 限制树的深度为3,使图更简洁
plt.show()

运行后第一幅图的效果如下,其中黑色虚线表示预测值和实际值相等的完美预测线。

【机器学习】决策树_第1张图片

我们使用plot_tree可视化的决策树如下:

【机器学习】决策树_第2张图片

在这个决策树上,我们可以看到:

  • 每个节点的决策条件(例如,某个特征的值是否小于某个阈值)。
  • 每个节点的不纯度(对于回归树,通常是MSE;对于分类树,可能是基尼不纯度或熵)。
  • 每个节点包含的样本数。
  • 每个节点的值(对于回归树,是节点中所有样本目标值的平均值;对于分类树,是每个类别的样本数)。
  • 叶节点的输出值。

我们解释一下其中一个节点(第二行左侧):

  • MedInc <= 3.074: 这是一个分裂条件,表示特征“MedInc”(中位收入)小于等于3.074的样本将被分配到左侧的子节点。
  • squared_error = 0.834: 这是节点内样本的均方误差,一个评估模型在该节点上表现的指标。它衡量的是模型的预测值与实际值之间差异的平方的平均值。
  • samples = 13101: 表示在该节点处有13101个样本。
  • value = 1.74: 对于回归决策树,这通常表示该节点样本目标值的平均值。

通过这种方式可以了解模型的决策过程,包括哪些特征对模型的预测结果有显著影响,以及这些特征是如何影响模型做出最终决策的。

例3:剪枝策略

我们只需要调整sklearn的DecisionTreeRegressor的参数即可进行一个预剪枝的操作。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# 加载加利福尼亚州房价数据集
housing = fetch_california_housing()
X = housing.data
y = housing.target

# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建决策树回归器实例,应用预剪枝策略
# 限制树的最大深度为3
# 设置节点在分裂前必须拥有的最小样本数为10
# 设置分裂节点时所需的最小增益为0.01
dtree_reg = DecisionTreeRegressor(max_depth=3, min_samples_split=10, min_impurity_decrease=0.01, random_state=42)

# 在训练集上训练决策树回归器
dtree_reg.fit(X_train, y_train)

# 对测试集进行预测
predictions = dtree_reg.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y_test, predictions)
print(f'决策树回归器的均方误差: {mse}')

# 可视化决策树
plt.figure(figsize=(20, 10))
plot_tree(dtree_reg, feature_names=housing.feature_names, filled=True, rounded=True, max_depth=3)
plt.show()

可视化结果如下:

【机器学习】决策树_第3张图片

通常在预剪枝中,可能是因为该节点不满足分裂条件(例如,它的信息增益小于设定的阈值,或者样本数量少于最小分裂样本数等),我们可以发现,在相比原本的实现,这时候的右下角的决策树少了两个节点。

你可能感兴趣的:(机器学习,Python,机器学习,决策树,人工智能)