import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import preprocessing
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
import csv
def read_data(filepath):
    """Load a labelled CSV file into feature dictionaries and class labels.

    The first row is treated as the header. For every following row the
    first column is ignored, the last column is taken as the label, and
    the columns in between are collected into a ``{header: value}`` dict.
    All values are kept as the strings produced by ``csv.reader``.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        A ``(list_data, list_label)`` tuple: one feature dict per data row
        and the matching list of labels. Both lists are empty when the
        file contains no header row.

    Raises:
        IndexError: If a non-blank row has fewer columns than the header
            requires.
    """
    list_data = []
    list_label = []
    # The context manager closes the file even when parsing fails;
    # newline="" is the mode the csv module asks for.
    with open(filepath, "r", newline="") as csv_file:
        reader = csv.reader(csv_file)
        headers = next(reader, None)
        if headers is None:
            # Completely empty file: nothing to load.
            return list_data, list_label
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            # Columns 1 .. len(headers)-2 are the features; index the row
            # by header position so short rows still fail loudly.
            list_data.append(
                {name: row[i] for i, name in enumerate(headers[1:-1], start=1)}
            )
            list_label.append(row[-1])
    return list_data, list_label
def trans_data(data):
    """Vectorize a list of feature dictionaries with ``DictVectorizer``.

    Args:
        data: List of ``{feature_name: value}`` dicts, one per sample.

    Returns:
        A ``(matrix, vec_name)`` tuple: the dense array obtained from
        ``DictVectorizer.fit_transform`` and the list of generated feature
        names reported by the fitted vectorizer.
    """
    vec = DictVectorizer()
    matrix = vec.fit_transform(data).toarray()
    # get_feature_names() was deprecated in scikit-learn 1.0 and removed in
    # 1.2; use its replacement when present and fall back on old releases.
    if hasattr(vec, "get_feature_names_out"):
        # Convert the returned array to a plain list of str so callers get
        # the same container type as with the legacy method.
        vec_name = vec.get_feature_names_out().tolist()
    else:
        vec_name = vec.get_feature_names()
    return (matrix, vec_name)
def trans_label(data):
    """Binarize class labels with ``preprocessing.LabelBinarizer``.

    Args:
        data: Sequence of raw class labels.

    Returns:
        A ``(encoded, classes)`` tuple: the result of
        ``LabelBinarizer.fit_transform`` on ``data`` and the fitted
        binarizer's ``classes_`` attribute.
    """
    binarizer = preprocessing.LabelBinarizer()
    encoded = binarizer.fit_transform(data)
    return (encoded, binarizer.classes_)
def main():
    """Train an entropy-criterion decision tree on ``data.csv`` and render it.

    Reads the CSV through ``read_data``, encodes features and labels with
    ``trans_data`` / ``trans_label``, fits a ``DecisionTreeClassifier``,
    prints the prediction for the first training sample, and finally
    exports the tree to Graphviz and renders it under the name "comuter".
    """
    print("start".center(10, "-"))
    csv_path = "data.csv"
    print("ing".center(10, "-"))

    rows, labels = read_data(csv_path)
    features, feature_names = trans_data(rows)
    targets, class_names = trans_label(labels)

    classifier = tree.DecisionTreeClassifier(criterion="entropy")
    classifier.fit(features, targets)

    # predict() expects a 2-D input, so give the first sample a leading
    # axis of length one.
    sample = features[0].reshape(1, -1)
    print("test:", sample)
    print("result:", classifier.predict(sample))
    print("end".center(10, "-"))

    # Imported here so graphviz is only required once rendering starts.
    import graphviz

    dot_source = tree.export_graphviz(
        classifier,
        out_file=None,
        feature_names=feature_names,
        class_names=class_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graphviz.Source(dot_source).render("comuter")
# Run the entropy-criterion decision-tree demo; the call is unguarded, so it
# executes whenever this file is run or imported.
main()
说明:首先,安装模块graphviz:pip install graphviz
其次,下载graphviz,https://www.graphviz.org/download/
之后,配置环境变量,将bin目录配置到系统path环境变量里.
最后,进入cmd,依次运行 dot -c 和 dot -v 进行测试。
成功后,运行代码即可。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import preprocessing
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
import csv
def read_data(filepath):
    """Load a numeric CSV file into a feature matrix and a label vector.

    The file is parsed with ``np.genfromtxt`` using ``,`` as delimiter.
    The first row (the header line) and the first column are discarded;
    the last column is returned as the labels and the columns in between
    as the features.

    Args:
        filepath: Path of the CSV file to read. The file is expected to
            have at least two columns.

    Returns:
        A ``(x_data, y_data)`` tuple: a 2-D float array of features and a
        1-D float array of labels. Both have zero rows when the file
        contains only the header line.
    """
    data = np.genfromtxt(filepath, delimiter=",")
    # genfromtxt collapses a single-line file into a 1-D array; promote it
    # to one row so the 2-D slicing below yields empty results instead of
    # raising IndexError.
    data = np.atleast_2d(data)
    x_data = data[1:, 1:-1]
    y_data = data[1:, -1]
    return x_data, y_data
def main():
    """Fit a default decision tree on ``cart.csv`` and render it.

    Loads features and labels through ``read_data``, prints them, fits a
    ``DecisionTreeClassifier`` with its default settings, then exports the
    tree to Graphviz and renders it under the name "cart".
    """
    print("start".center(10, "-"))
    csv_path = "cart.csv"
    print("ing".center(10, "-"))

    features, targets = read_data(csv_path)
    print("x_data={}\ny_data={}.".format(features, targets))

    # No criterion is passed, so the classifier's default (Gini impurity)
    # is used.
    classifier = tree.DecisionTreeClassifier()
    classifier.fit(features, targets)

    # Imported here so graphviz is only required once rendering starts.
    import graphviz

    column_names = [
        "house_yes",
        "house_no",
        "single",
        "married",
        "divorced",
        "income",
    ]
    dot_source = tree.export_graphviz(
        classifier,
        out_file=None,
        feature_names=column_names,
        class_names=["no", "yes"],
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graphviz.Source(dot_source).render("cart")
    print("end".center(10, "-"))
# Run the default-criterion decision-tree demo on cart.csv; the call is
# unguarded, so it executes whenever this file is run or imported.
main()