Decision Tree Algorithm Example

# -*- coding: utf-8 -*-

'''
Created on Oct 12, 2010
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
'''
from math import log      
import operator

# Training data set
def createDataSet():
    # Two possible class labels: 'yes' and 'no'
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    # Names of the two features
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return dataSet, labels

# Compute the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    # Total number of samples
    numEntries = len(dataSet)
    # Dictionary of <key, value> pairs:
    # key is a class label, value is its occurrence count
    labelCounts = {}
    # Iterate over every row of dataSet
    for featVec in dataSet:
        # The last column of each row is the class label
        currentLabel = featVec[-1]
        # *** Count the number of samples in each class;
        # extend the dictionary when a new label is seen
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        # One more occurrence of the current class
        labelCounts[currentLabel] += 1
    # Accumulator for the entropy
    shannonEnt = 0.0
    # After scanning all samples, compute each class's share of the total
    for key in labelCounts:
        # Proportion of samples belonging to this class
        prob = float(labelCounts[key]) / numEntries
        # Entropy uses the base-2 logarithm
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
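
# Example: the training set above has 2 'yes' and 3 'no' labels, so its
# entropy is H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710, e.g.:
#   myDat, labels = createDataSet()
#   print(calcShannonEnt(myDat))   # -> 0.9709505944546686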

#***** Split the data set on a given feature value
# dataSet: data set to be split
# axis: index of the feature used for splitting
# value: the feature value to select
def splitDataSet(dataSet, axis, value):
    # New list holding the split-off subset
    retDataSet = []
    # Iterate over every row of dataSet
    for featVec in dataSet:
        # Keep rows whose feature at index axis equals the given value
        if featVec[axis] == value:
            # Take columns [0] .. [axis-1]
            reducedFeatVec = featVec[:axis]
            # Append columns [axis+1] .. [last]
            reducedFeatVec.extend(featVec[axis+1:])
            # The column at [axis] itself is cut out
            retDataSet.append(reducedFeatVec)
    return retDataSet
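
# Example: splitting myDat on feature 0 with value 1 keeps the three rows
# whose first column is 1 and removes that column:
#   splitDataSet(myDat, 0, 1)   # -> [[1, 'yes'], [1, 'yes'], [0, 'no']]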

#***** Choose the best feature to split on, i.e. the one with the
# largest information gain
# dataSet: data set to be split
def chooseBestFeatureToSplit(dataSet):
    # Number of features (the last column is the class label)
    numFeatures = len(dataSet[0]) - 1
    # Entropy of the whole data set
    baseEntropy = calcShannonEnt(dataSet)
    # Largest information gain seen so far, initialized to 0
    bestInfoGain = 0.0
    # Feature with the largest information gain, initialized to -1 (none)
    bestFeature = -1
    # Iterate over all features
    for i in range(numFeatures):
        # Extract column i
        featList = [example[i] for example in dataSet]
        # All distinct values of feature i
        uniqueVals = set(featList)
        # Conditional entropy, initialized to 0
        newEntropy = 0.0
        # Iterate over all values of this feature
        for value in uniqueVals:
            # Split dataSet into the subset where feature i == value
            subDataSet = splitDataSet(dataSet, i, value)
            # Proportion of the whole data set that falls into this subset
            prob = len(subDataSet) / float(len(dataSet))
            # Accumulate the conditional entropy after splitting on feature i
            newEntropy += prob * calcShannonEnt(subDataSet)
        # Information gain = overall entropy - conditional entropy
        infoGain = baseEntropy - newEntropy
        # Track the largest information gain and its feature
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i  # feature with the largest information gain
    return bestFeature
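
# Worked numbers for the sample set: splitting on feature 0 ('no surfacing')
# leaves a conditional entropy of (2/5)*0 + (3/5)*0.9183 ≈ 0.5510, a gain of
# 0.9710 - 0.5510 ≈ 0.4200; splitting on feature 1 ('flippers') gives a gain
# of only 0.9710 - 0.8000 ≈ 0.1710, so chooseBestFeatureToSplit(myDat)
# returns 0.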

# When all features are used up, return the majority class in the data set
# classList: list of class labels
def majorityCnt(classList):
    # <key, value> dictionary: key is a class label, value its sample count
    classCount = {}
    # Iterate over the class labels
    for vote in classList:
        # Add a label to the dictionary the first time it is seen
        if vote not in classCount:
            classCount[vote] = 0
        # Count one more sample of this class
        classCount[vote] += 1
    # Sort by count in descending order (most frequent class first)
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    # Return the class with the most samples
    return sortedClassCount[0][0]
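
# Example: majorityCnt(['yes', 'no', 'no']) returns 'no'. An equivalent
# standard-library one-liner would be
#   collections.Counter(classList).most_common(1)[0][0]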

#***** Recursively build the decision tree
# dataSet: data set
# labels: feature names
def createTree(dataSet, labels):
    # Last column of the data set, i.e. the class labels
    classList = [example[-1] for example in dataSet]
    # * Stop splitting when all samples in the subset share one class;
    #   return that class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # * Stop when no features are left; return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # * Recursion has not terminated:
    # index of the best splitting feature, i.e. the one with the
    # largest information gain
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Name of the best splitting feature
    bestFeatLabel = labels[bestFeat]
    # Use that feature name as the root of this subtree
    myTree = {bestFeatLabel: {}}
    # Remove the chosen feature's name from the list of feature names
    # (note: this mutates the caller's list)
    del(labels[bestFeat])
    # Extract the column of the best splitting feature
    featValues = [example[bestFeat] for example in dataSet]
    # Distinct values of that feature (deduplicated)
    uniqueVals = set(featValues)
    # Iterate over the values of the best splitting feature
    for value in uniqueVals:
        # Copy of the feature names with the chosen feature removed
        subLabels = labels[:]
        # Recurse on each subset and attach the result as a branch of the tree
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
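
# On the sample data this recursion yields the nested-dictionary tree
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# where each inner dict is a decision node and each string is a leaf class.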

#***** Classify with the trained decision tree
# inputTree: the trained decision tree
# featLabels: the feature names
# testVec: the test sample
def classify(inputTree, featLabels, testVec):
    # Name of the root node of the tree (e.g. 'no surfacing')
    firstStr = list(inputTree.keys())[0]
    print('firstStr =', firstStr)
    print('featLabels =', featLabels)
    # Everything below the root node,
    # e.g. {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    secondDict = inputTree[firstStr]
    print('secondDict =', secondDict)
    # index() finds the first list element matching firstStr,
    # i.e. which feature column the root node corresponds to
    featIndex = featLabels.index(firstStr)
    print('featIndex =', featIndex)

    # The test sample's value for the root node's feature
    key = testVec[featIndex]
    # e.g. secondDict[0] = 'no', secondDict[1] = {'flippers': {0: 'no', 1: 'yes'}}
    valueOfFeat = secondDict[key]
    # Check the type of valueOfFeat:
    # if it is a dict, descend into the subtree recursively
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    # otherwise it is a leaf, so return it directly as the final class
    # (here e.g. valueOfFeat = secondDict[key] = 'no' returns 'no')
    else:
        classLabel = valueOfFeat
    # Return the final class label
    return classLabel
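
# Usage sketch (featLabels must be the original, unmodified feature list,
# since createTree() deletes entries from the list passed to it):
#   classify(myTree, ['no surfacing', 'flippers'], [1, 0])   # -> 'no'
#   classify(myTree, ['no surfacing', 'flippers'], [1, 1])   # -> 'yes'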

# Serialize the trained tree to disk with pickle
def storeTree(inputTree, filename):
    import pickle
    # pickle writes bytes, so open the file in binary mode
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

# Load a previously stored tree from disk
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
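
# Round-trip sketch (the file name is arbitrary):
#   storeTree(myTree, 'classifierStorage.txt')
#   print(grabTree('classifierStorage.txt'))   # prints the same nested dict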



#==================================================================
#### Plotting: draw the constructed tree
# import matplotlib.pyplot as plt
#
# decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# leafNode = dict(boxstyle="round4", fc="0.8")
# arrow_args = dict(arrowstyle="<-")
#
# def getNumLeafs(myTree):
#     numLeafs = 0
#     firstStr = list(myTree.keys())[0]
#     secondDict = myTree[firstStr]
#     for key in secondDict.keys():
#         # test whether the node is a dictionary; if not, it is a leaf node
#         if type(secondDict[key]).__name__ == 'dict':
#             numLeafs += getNumLeafs(secondDict[key])
#         else: numLeafs += 1
#     return numLeafs
#
# def getTreeDepth(myTree):
#     maxDepth = 0
#     firstStr = list(myTree.keys())[0]
#     secondDict = myTree[firstStr]
#     for key in secondDict.keys():
#         # test whether the node is a dictionary; if not, it is a leaf node
#         if type(secondDict[key]).__name__ == 'dict':
#             thisDepth = 1 + getTreeDepth(secondDict[key])
#         else: thisDepth = 1
#         if thisDepth > maxDepth: maxDepth = thisDepth
#     return maxDepth
#
# def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#     createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
#                             xytext=centerPt, textcoords='axes fraction',
#                             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
#
# def plotMidText(cntrPt, parentPt, txtString):
#     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
#     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
#     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
#
# def plotTree(myTree, parentPt, nodeTxt):  # the first key tells what feature was split on
#     numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
#     depth = getTreeDepth(myTree)
#     firstStr = list(myTree.keys())[0]  # the text label for this node
#     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#     plotMidText(cntrPt, parentPt, nodeTxt)
#     plotNode(firstStr, cntrPt, parentPt, decisionNode)
#     secondDict = myTree[firstStr]
#     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
#     for key in secondDict.keys():
#         # a dict value is a subtree, so recurse; otherwise plot the leaf node
#         if type(secondDict[key]).__name__ == 'dict':
#             plotTree(secondDict[key], cntrPt, str(key))  # recursion
#         else:
#             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
#             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
#             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
#     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#
# def createPlot(inTree):
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()
#     axprops = dict(xticks=[], yticks=[])
#     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
#     #createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#     plotTree.totalW = float(getNumLeafs(inTree))
#     plotTree.totalD = float(getTreeDepth(inTree))
#     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
#     plotTree(inTree, (0.5, 1.0), '')
#     plt.show()

##def createPlot():
##    fig = plt.figure(1, facecolor='white')
##    fig.clf()
##    createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
##    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
##    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
##    plt.show()

##def retrieveTree(i):
##    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
##                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
##                  ]
##    return listOfTrees[i]
##
###createPlot(thisTree)
#### Plotting: draw the constructed tree
#==================================================================




if __name__ == "__main__":
    mydat, labels = createDataSet()
    # createTree() deletes entries from labels as features are used up,
    # so back up the original feature names before building the tree
    featlabels = labels[:]

    mytree = createTree(mydat, labels)
    print(mytree)


#**** Contact lens data set
##
## fr = open('lenses.txt', 'r')
## lenses = [line.strip().split('\t') for line in fr.readlines()]
## lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
## lensesTree = createTree(lenses, lensesLabels)
## print(lensesTree)
## createPlot(lensesTree)  # plot the constructed tree
