【机器学习|学习笔记】用 Python 结合 graphviz 生成 ID3、C4.5、CART 三种决策树的结构示意图
欢迎铁子们点赞、关注、收藏!
祝大家逢考必过!逢投必中!上岸上岸上岸!upupup
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文。详细信息可关注VX “
学术会议小灵通
”或参考学术信息专栏:https://blog.csdn.net/2401_89898861/article/details/147196847
pip install graphviz
此外还需要安装 Graphviz 软件支持,官网:https://graphviz.org/download/
from graphviz import Digraph
def draw_id3_tree():
    """Draw a schematic ID3 decision tree (splits chosen by information gain).

    Renders the classic "play tennis" layout: Outlook at the root, with
    Humidity tested on the Sunny branch and Wind on the Rain branch.
    Writes id3_tree.pdf next to the script and opens it (view=True).
    """
    dot = Digraph(comment='ID3 Tree')
    dot.node('A', 'Outlook\n(Info Gain)')
    dot.node('B', 'Humidity')
    dot.node('C', 'Wind')
    dot.node('D', 'No')
    dot.node('E', 'Yes')
    # Fix: the original used dot.edges(['AB', 'AC']), which left the root
    # branches unlabeled — inconsistent with draw_cart_tree, where every
    # branch carries its outcome. Label them with the Outlook values.
    dot.edge('A', 'B', label='Sunny')
    dot.edge('A', 'C', label='Rain')
    dot.edge('B', 'D', label='High')
    dot.edge('B', 'E', label='Normal')
    dot.edge('C', 'D', label='Strong')
    dot.edge('C', 'E', label='Weak')
    dot.render('id3_tree', view=True)
def draw_c45_tree():
    """Draw a schematic C4.5 decision tree (splits chosen by gain ratio).

    Same structure as the ID3 sketch — only the root annotation differs —
    which mirrors the fact that C4.5 refines ID3's split criterion.
    Writes c45_tree.pdf and opens it (view=True).
    """
    dot = Digraph(comment='C4.5 Tree')
    dot.node('A', 'Outlook\n(Info Gain Ratio)')
    dot.node('B', 'Humidity')
    dot.node('C', 'Wind')
    dot.node('D', 'No')
    dot.node('E', 'Yes')
    # Fix: the original used dot.edges(['AB', 'AC']), leaving the root
    # branches unlabeled — inconsistent with draw_cart_tree. Label them
    # with the Outlook values so the diagram reads as a decision tree.
    dot.edge('A', 'B', label='Sunny')
    dot.edge('A', 'C', label='Rain')
    dot.edge('B', 'D', label='High')
    dot.edge('B', 'E', label='Normal')
    dot.edge('C', 'E', label='Weak')
    dot.edge('C', 'D', label='Strong')
    dot.render('c45_tree', view=True)
def draw_cart_tree():
    """Draw a schematic CART decision tree (binary splits via Gini index).

    Unlike the ID3/C4.5 sketches, CART uses binary threshold questions
    (e.g. "Humidity <= 75?"). Writes cart_tree.pdf and opens it (view=True).
    """
    graph = Digraph(comment='CART Tree')
    # Node id -> display label, in insertion order.
    labels = {
        'A': 'Outlook\n(Gini Index)',
        'B': 'Humidity <= 75?',
        'C': 'Wind',
        'D': 'Yes',
        'E': 'No',
    }
    for node_id, text in labels.items():
        graph.node(node_id, text)
    # (tail, head, branch label) triples, in the original drawing order.
    branches = [
        ('A', 'B', 'Sunny'),
        ('A', 'C', 'Overcast/Rain'),
        ('B', 'D', 'Yes'),
        ('B', 'E', 'No'),
        ('C', 'D', 'Weak'),
        ('C', 'E', 'Strong'),
    ]
    for tail, head, branch in branches:
        graph.edge(tail, head, label=branch)
    graph.render('cart_tree', view=True)
# Render all three demo trees. Each call writes <name>.pdf (plus the raw
# DOT source) into the working directory and, because the functions pass
# view=True, opens the PDF in the system's default viewer.
draw_id3_tree()
draw_c45_tree()
draw_cart_tree()
运行后会生成 id3_tree.pdf、c45_tree.pdf、cart_tree.pdf 三个文件(视图可根据系统配置自动弹出)。
接下来改用 sklearn 训练真实决策树模型,然后导出并用 graphviz 可视化,
分别展示 ID3(用信息增益替代)、C4.5(近似用信息增益比)和 CART 的真实结构。
(注意:sklearn 原生只支持 CART,ID3 和 C4.5 需自己改写或近似模拟,下面方案做了相应说明和替代)
pip install scikit-learn graphviz
sklearn 的 DecisionTreeClassifier
默认就是 CART,支持基尼指数和熵(信息增益)两种分裂标准:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import graphviz
# Load the iris dataset (numeric features, three target classes).
data = load_iris()
X, y = data.data, data.target
feature_names = data.feature_names
class_names = data.target_names
# Train CART (Gini index) — sklearn's native decision-tree algorithm.
# random_state fixed so the rendered tree is reproducible.
clf_cart = DecisionTreeClassifier(criterion='gini', random_state=42)
clf_cart.fit(X, y)
# Train an ID3 approximation: entropy (information gain) as the split
# criterion. NOTE: sklearn still builds a binary CART-style tree, so
# this is not strict ID3 — it only approximates its split choice.
clf_id3 = DecisionTreeClassifier(criterion='entropy', random_state=42)
clf_id3.fit(X, y)
def export_tree(clf, feature_names, class_names, file_name):
    """Export a fitted sklearn tree to DOT and render it as <file_name>.pdf.

    clf           -- a fitted DecisionTreeClassifier
    feature_names -- labels for the split features
    class_names   -- labels for the target classes
    file_name     -- output basename (Graphviz appends the .pdf suffix)
    """
    render_options = dict(
        out_file=None,            # return the DOT text instead of writing a file
        feature_names=feature_names,
        class_names=class_names,
        filled=True,              # color nodes by majority class
        rounded=True,
        special_characters=True,
    )
    dot_source = export_graphviz(clf, **render_options)
    graphviz.Source(dot_source).render(file_name)
    print(f"{file_name}.pdf 已生成")
# Render the CART tree (Gini) -> cart_iris.pdf
export_tree(clf_cart, feature_names, class_names, "cart_iris")
# Render the entropy-based ID3 approximation -> id3_iris.pdf
export_tree(clf_id3, feature_names, class_names, "id3_iris")
若需要严格的 C4.5,可以使用第三方库 sklearn-c4.5(不过维护有限)或手写实现。
运行后生成:
cart_iris.pdf:CART 树结构,基尼指数,二叉树;
id3_iris.pdf:用信息熵替代的 ID3 近似树,非严格 ID3 但演示相似思路。完整代码如下:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import graphviz
# Load the iris dataset (numeric features, three target classes).
data = load_iris()
X, y = data.data, data.target
feature_names = data.feature_names
class_names = data.target_names
# Train CART (Gini index) — sklearn's native decision-tree algorithm;
# random_state fixed for a reproducible tree.
clf_cart = DecisionTreeClassifier(criterion='gini', random_state=42)
clf_cart.fit(X, y)
# Train an ID3 approximation (entropy criterion). sklearn still builds
# a binary tree, so this only approximates ID3's split choice.
clf_id3 = DecisionTreeClassifier(criterion='entropy', random_state=42)
clf_id3.fit(X, y)
def export_tree(clf, feature_names, class_names, file_name):
    """Render a fitted decision tree as a PDF via Graphviz.

    Converts *clf* to DOT text with export_graphviz, then renders it to
    <file_name>.pdf and prints a confirmation message.
    """
    dot_text = export_graphviz(
        clf,
        out_file=None,            # get the DOT source back as a string
        feature_names=feature_names,
        class_names=class_names,
        filled=True,              # fill nodes with the majority-class color
        rounded=True,
        special_characters=True,
    )
    source = graphviz.Source(dot_text)
    source.render(file_name)
    print(f"{file_name}.pdf 已生成")
# Export both trained trees to PDF (cart_iris.pdf and id3_iris.pdf).
export_tree(clf_cart, feature_names, class_names, "cart_iris")
export_tree(clf_id3, feature_names, class_names, "id3_iris")