[scikit-learn] Visualizing a Decision Tree

Preparing the model

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import tree
import graphviz
# Load the dataset
data = load_iris()
# Convert it to a DataFrame
df = pd.DataFrame(data.data, columns=data.feature_names)
# Add the species column
df['Species'] = data.target

# Replace the numeric class codes with the species names to use as labels
targets = dict(enumerate(data.target_names))
df['Species'] = df['Species'].replace(targets)

# Separate the features and the labels
X = df.drop(columns="Species")
y = df["Species"]
# A plain list of column names works across scikit-learn versions
feature_names = list(X.columns)
# Sorted so the order matches model.classes_, which class_names relies on
labels = sorted(y.unique())

# Hold out 40% of the samples as a test set
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.4,
                                                    random_state=42)
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)

DecisionTreeClassifier(max_depth=3, random_state=42)

Text representation

# Print the tree as text
text_representation = tree.export_text(model)
print(text_representation)
|--- feature_2 <= 2.45
|   |--- class: setosa
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 5.35
|   |   |   |--- class: versicolor
|   |   |--- feature_2 >  5.35
|   |   |   |--- class: virginica
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- class: virginica
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: virginica
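The output above refers to the columns as feature_2 and feature_3 because export_text was called without feature names. As a minimal sketch, passing feature_names gives more readable rules. The names iris, clf and readable are illustrative, and the tree here is fitted on the full iris data rather than the 60% training split, so its thresholds may differ from the output above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Illustrative setup: fit on the whole dataset (the article fits on a train split)
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# feature_names replaces the generic feature_<index> placeholders with column names
readable = export_text(clf, feature_names=list(iris.feature_names))
print(readable)
```

With the article's own objects, the equivalent call would be tree.export_text(model, feature_names=list(X.columns)).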

Plotting with plot_tree

# Draw the tree as a figure
plt.figure(figsize=(30,10), facecolor ='g')
a = tree.plot_tree(model,
                   feature_names = feature_names,
                   class_names = labels,
                   rounded = True,
                   filled = True,
                   fontsize=14)
plt.show()  
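plt.show() only displays the figure. If the image should be kept, for instance when running a script without a display, one option is to draw onto an explicit Axes and save it. This is a sketch under assumptions: the file name iris_tree.png, the Agg backend and the names iris, clf and out_path are illustrative choices, not part of the original code.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is required
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Illustrative setup: fit on the whole dataset
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(30, 10))
plot_tree(clf,
          feature_names=list(iris.feature_names),
          class_names=list(iris.target_names),  # same order as clf.classes_
          rounded=True,
          filled=True,
          fontsize=14,
          ax=ax)

out_path = "iris_tree.png"  # hypothetical output file name
fig.savefig(out_path, dpi=100, bbox_inches="tight")
plt.close(fig)
print(out_path, os.path.getsize(out_path), "bytes")
```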

[Figure 1: the decision tree drawn with plot_tree]

Using graphviz

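To use graphviz, first download and install the Graphviz package from its download page, then add its bin directory to the PATH environment variable.
Next, install the Python bindings in the relevant conda environment with pip install graphviz or conda install python-graphviz (on conda, the package named graphviz provides the Graphviz binaries rather than the Python bindings).
Restart the computer so the PATH change takes effect, then run a quick test: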

from graphviz import Digraph
 
g = Digraph('test_graph')  # the graph name is also used for the output file name
g.node(name='a',color='red')
g.node(name='b',color='blue')
g.edge('a','b',color='green')
g.view()

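Running this produces a graph with a single edge from a to b.
[Figure 2: the test graph with an edge from a to b]
If it raises an error, try restarting the computer first.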
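A common cause of errors at this step is that the Graphviz executables are not visible on PATH. As a rough check, the standard library can look up the dot executable that the Python bindings call. The helper name find_graphviz_dot is hypothetical.

```python
import shutil


def find_graphviz_dot():
    """Return the full path of the Graphviz `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")


dot_path = find_graphviz_dot()
if dot_path is None:
    print("dot was not found on PATH; check that the Graphviz bin directory was added")
else:
    print(f"dot found at {dot_path}")
```

If dot is not found in a freshly opened terminal, the bin directory was probably not added to PATH correctly.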

# DOT data
dot_data = tree.export_graphviz(model, out_file=None, 
                                feature_names=data.feature_names,  
                                class_names=data.target_names,
                                filled=True)

# Draw graph
graph = graphviz.Source(dot_data, format="png") 
graph
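With out_file=None, export_graphviz returns the DOT source as an ordinary string, so it can be inspected or written to disk even before the Graphviz binaries are working. A minimal sketch, where the file name iris_tree.dot and the names iris, clf and dot_source are illustrative assumptions:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Illustrative setup: fit on the whole dataset
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source instead of writing a file
dot_source = export_graphviz(clf, out_file=None,
                             feature_names=list(iris.feature_names),
                             class_names=list(iris.target_names),
                             filled=True)

# Save the DOT text; this step does not need the Graphviz binaries
with open("iris_tree.dot", "w", encoding="utf-8") as f:
    f.write(dot_source)

print(dot_source.splitlines()[0])
```

Once Graphviz is installed, the saved file can be rendered from the command line with dot -Tpng iris_tree.dot -o iris_tree.png.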

[Figure 3: the decision tree rendered with graphviz]
