前一篇文章主要讲述了第三章的前部分,主要是决策树的算法实现部分
这一部分,因为篇幅较长,我们单独将他们分出来。包括:决策树可视化和分类预测
=======================================================================
matplotlib的文本注释绘制树节点
使用了Matplotlib中的annotate工具
import matplotlib.pyplot as plt

# Styles for the three kinds of drawing primitives used by the tree plotter:
# decision (internal) nodes get a sawtooth box, leaf nodes a rounded box,
# and edges are annotation arrows pointing back toward the child text.
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated tree node with an arrow from its parent.

    nodeTxt  -- text shown inside the node box
    centerPt -- (x, y) of this node, in axes-fraction coordinates
    parentPt -- (x, y) of the parent node the arrow originates from
    nodeType -- bbox style dict, either decisionNode or leafNode
    """
    # The axes object is stashed on createPlot as a function attribute.
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def createPlot():
    """Demo entry point: draw one decision node and one leaf on a blank axes."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # Expose the axes via a function attribute so plotNode can find it.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # One node of each style, at fixed demo coordinates.
    plotNode(U'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode(U'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
注意:运行后中文节点文本可能没有显示出来,这通常是因为Matplotlib默认字体不支持中文,可在绘图前设置 plt.rcParams['font.sans-serif']=['SimHei'] 等中文字体来解决。
def getNumLeafs(myTree):
    """Count the leaf nodes of a tree stored as nested dicts.

    A tree has the shape {feature_label: {feature_value: subtree_or_class}};
    any non-dict branch value is a leaf (a class label).
    """
    root = list(myTree.keys())[0]
    branches = myTree[root]
    total = 0
    for child in branches.values():
        # dict branch -> recurse into the subtree; anything else is one leaf.
        total += getNumLeafs(child) if isinstance(child, dict) else 1
    return total
在此有个地方跟书中不一样:在Python 3中 myTree.keys() 返回的是视图对象(view)而不是列表,不能直接用下标索引,所以改为了 list(myTree.keys())[0]
def getTreeDepth(myTree):
    """Return the depth of the nested-dict tree (a direct leaf counts as 1)."""
    root = list(myTree.keys())[0]
    # Depth through each branch: 1 for this level plus the subtree below it.
    childDepths = [
        1 + getTreeDepth(child) if isinstance(child, dict) else 1
        for child in myTree[root].values()
    ]
    return max(childDepths, default=0)
def retrieveTree(i):
    """Return the i-th of two hard-coded sample trees (for testing the plotter)."""
    samples = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return samples[i]
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge between cntrPt and parentPt."""
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString)
def plotTree(myTree,parentPt,nodeTxt):
    """Recursively draw the subtree rooted at myTree.

    Relies on layout state stored as function attributes by createPlot:
    plotTree.totalW / plotTree.totalD -- total leaf count / tree depth,
        used to scale horizontal / vertical spacing;
    plotTree.xoff / plotTree.yoff -- running plot cursor in axes fractions.
    """
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)  # computed but unused below
    firstStr=list(myTree.keys())[0]
    # Center this subtree's root horizontally above the span of its leaves.
    cntrPt=(plotTree.xoff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yoff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    # Step the y cursor down one level before drawing children.
    plotTree.yoff=plotTree.yoff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key],dict):
            # Non-leaf branch: recurse into the subtree, labeling the edge
            # with the branch's feature value.
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            # Leaf: advance the x cursor one leaf-slot, draw the leaf node
            # and label the connecting edge.
            plotTree.xoff=plotTree.xoff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff),cntrPt,leafNode)
            plotMidText((plotTree.xoff,plotTree.yoff),cntrPt,str(key))
    # Restore the y cursor while unwinding the recursion.
    plotTree.yoff=plotTree.yoff+1.0/plotTree.totalD
def createPlot(inTree):
    """Main entry point: size the layout from inTree and draw it via plotTree."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Seed the global layout state consumed by plotTree: overall width/depth
    # and the initial plot cursor (half a leaf-slot left of the first leaf).
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
def classify(inputTree, featLabels, testVec):
    """Classify a sample with a trained decision tree.

    inputTree  -- nested-dict tree {feature_label: {feature_value: subtree_or_class}}
    featLabels -- feature labels in the same order as the dimensions of testVec
    testVec    -- feature values of the sample to classify

    Returns the predicted class label, or None when testVec's value for the
    split feature matches no branch (the original left classLabel unbound in
    that case and raised UnboundLocalError).
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # Map this node's feature label to its index within testVec.
    featIndex = featLabels.index(firstStr)
    classLabel = None  # bug fix: defined even if no branch matches
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                # Matching branch leads to a subtree: recurse one level down.
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                # Matching branch is a leaf: its value is the class.
                classLabel = secondDict[key]
    return classLabel
import trees
import treePlotter

# Build the demo data set and fetch a pre-built sample tree, then classify
# a sample whose feature vector is [1, 0].
myDAT,labels=trees.createDataset()
myTree=treePlotter.retrieveTree(0)
# NOTE(review): classify is defined above in this article but invoked here as
# trees.classify — confirm which module actually provides it.
print(trees.classify(myTree,labels,[1,0]))
def storeTree(inputTree, filename):
    """Serialize the decision tree to filename with pickle.

    Bug fix: use a 'with' block so the file is closed even if pickling
    raises, instead of relying on an explicit close() afterwards.
    Binary mode ('wb') is required by pickle.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load a pickled decision tree from filename.

    Bug fix: the original opened the file and returned without ever closing
    the handle; 'with' guarantees it is closed. Binary mode ('rb') matches
    the 'wb' used by storeTree.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
fw=open(filename,'wb')
和
fr=open(filename,'rb')
中的 'wb'/'rb' 表示以二进制(Byte)形式写入和读取文件,pickle 序列化要求使用二进制模式。
用决策树预测隐形眼镜类型:
# Predict contact-lens type from the lenses data set.
fr=open('lenses.txt')
# NOTE(review): fr is never closed here — a 'with open(...)' block would be safer.
# Parse the tab-separated file into a list of feature-value rows.
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
# The data set has four feature dimensions, labeled in column order.
lensesLabel=['age','prescript','astigmatic','tearRate']
# Build the decision tree from the parsed rows.
lensesTree=trees.createTree(lenses,lensesLabel)
print(lensesTree)
# Draw the resulting tree.
treePlotter.createPlot(lensesTree)
============================================================================================================
决策树这一章终于结束啦,后面的内容会越来越难,坚持!加油!