"""
Created on Sat Aug 24 11:14:58 2019
@author: wangtao_zuel
E-mail: [email protected]
CART decision tree (Classification and Regression Trees)
"""
import numpy as np
import pandas as pd
def loadData(filepath,fileType):
"""
    Return the data as a numpy matrix. If the sample labels are strings, adjust the reading logic accordingly.
"""
if fileType == 'xlsx':
data = pd.read_excel(filepath)
elif fileType == 'csv':
data = pd.read_csv(filepath)
else:
data = pd.read_csv(filepath,sep='\t',header=None)
data = np.mat(data)
return data
def binSplitDataSet(dataSet,featInd,featVal):
"""
    Split the samples in two by feature index and feature value; rows with feature <= featVal always go to the left matrix (matL).
"""
matL = dataSet[np.nonzero(dataSet[:,featInd] <= featVal)[0],:]
matR = dataSet[np.nonzero(dataSet[:,featInd] > featVal)[0],:]
return matL,matR
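# A minimal sketch of how binSplitDataSet partitions a matrix (toy data, purely
# illustrative): rows whose value in column 0 is <= 1 go to matL, the rest to matR.
if __name__ == '__main__':
    _toy = np.mat([[0.0, 5.0],
                   [1.0, 6.0],
                   [2.0, 7.0]])
    _l, _r = binSplitDataSet(_toy, 0, 1)
    print(_l)   # rows [0, 5] and [1, 6]
    print(_r)   # row  [2, 7]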
def regLeaf(dataSet):
"""
    Create a leaf node.
    Returns the mean of the target values under the branch; suitable for regression.
"""
return np.mean(dataSet[:,-1])
def maxLeaf(dataSet):
"""
    Create a leaf node.
    Returns the most frequent class under the branch; suitable for classification.
"""
results = uniqueCount(dataSet)
return max(results,key=results.get)
def uniqueCount(dataMat):
"""
    Count the number of samples in each class.
    Note: this expects matrix-type data; for other data types, adapt the iteration over dataMat[:,-1].T.tolist()[0].
"""
results = {}
for sample in dataMat[:,-1].T.tolist()[0]:
if sample not in results:
results[sample] = 0
results[sample] += 1
return results
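# A quick illustration (toy labels, hypothetical): uniqueCount tallies the last
# column, maxLeaf returns the majority label, regLeaf the mean of the targets.
if __name__ == '__main__':
    _toy = np.mat([[0.0, 1.0],
                   [0.0, 1.0],
                   [1.0, 2.0]])
    print(uniqueCount(_toy))   # {1.0: 2, 2.0: 1}
    print(maxLeaf(_toy))       # 1.0 (majority class)
    print(regLeaf(_toy))       # 1.333... (mean of the last column)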
def regErr(dataSet):
"""
    Error calculation.
    Uses the total squared error (variance times sample count); suitable for regression.
"""
var = np.var(dataSet[:,-1])
m = dataSet.shape[0]
err = m*var
return err
def entErr(dataSet):
"""
    Shannon entropy as the measure of error (impurity).
"""
results = uniqueCount(dataSet)
sampleNum = dataSet.shape[0]
shannonEnt = 0.0
for key in results:
prob = float(results[key])/sampleNum
shannonEnt -= prob*np.log2(prob)
return shannonEnt
def giniErr(dataSet):
"""
    Gini impurity as the measure of error (impurity).
"""
sampleNum = dataSet.shape[0]
results = uniqueCount(dataSet)
imp = 0.0
for k1 in results:
p1 = float(results[k1])/sampleNum
for k2 in results:
if k1 == k2:
continue
p2 = float(results[k2])/sampleNum
imp += p1*p2
return imp
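# A worked check of the three impurity measures on a toy set with labels
# [0, 0, 1, 1] (hypothetical values, verified by hand):
#   squared error: 4 * var = 4 * 0.25 = 1.0
#   entropy:       -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0
#   Gini:          p0*p1 + p1*p0 = 2 * 0.5 * 0.5 = 0.5
if __name__ == '__main__':
    _toy = np.mat([[1.0, 0.0],
                   [2.0, 0.0],
                   [3.0, 1.0],
                   [4.0, 1.0]])
    print(regErr(_toy))    # 1.0
    print(entErr(_toy))    # 1.0
    print(giniErr(_toy))   # 0.5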
def chooseBestSplit(dataSet,leafType,errType,ops):
"""
    Select the best splitting feature (index) and feature value.
"""
minErr = ops[0]
minNum = ops[1]
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None,leafType(dataSet)
m,n = dataSet.shape
basicErr = errType(dataSet)
bestErr = np.inf
bestInd = 0
bestVal = 0
for featInd in range(n-1):
for featVal in set(dataSet[:,featInd].T.tolist()[0]):
matL,matR = binSplitDataSet(dataSet,featInd,featVal)
if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
continue
newErr = errType(matL) + errType(matR)
            if newErr < bestErr:  # keep the split with the lowest combined error
bestInd = featInd
bestVal = featVal
bestErr = newErr
if (basicErr - bestErr) < minErr:
return None,leafType(dataSet)
matL,matR = binSplitDataSet(dataSet,bestInd,bestVal)
if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
return None,leafType(dataSet)
return bestInd,bestVal
def creatTree(dataSet,leafType,errType,ops):
"""
    Build the tree recursively.
"""
spInd,spVal = chooseBestSplit(dataSet,leafType,errType,ops)
    if spInd is None:
return spVal
tree = {}
tree['spInd'] = spInd
tree['spVal'] = spVal
matL,matR = binSplitDataSet(dataSet,spInd,spVal)
tree['left'] = creatTree(matL,leafType,errType,ops)
tree['right'] = creatTree(matR,leafType,errType,ops)
return tree
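# A minimal end-to-end sketch of creatTree on a tiny regression set (hypothetical
# data): feature values 0-7, targets 0 for the first half and 10 for the second.
# With ops=(1, 4) the only admissible split is at feature 0, value 3, giving a
# one-level tree with leaf means 0.0 and 10.0.
if __name__ == '__main__':
    _x = np.mat([[float(i), 0.0 if i < 4 else 10.0] for i in range(8)])
    print(creatTree(_x, regLeaf, regErr, (1, 4)))
    # expected: {'spInd': 0, 'spVal': 3.0, 'left': 0.0, 'right': 10.0}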
"""
# 后剪枝操作
"""
def isTree(obj):
"""
    Return True if the branch is a subtree (a dict) rather than a leaf.
"""
    return isinstance(obj, dict)
def getMean(tree):
"""
    Collapse a subtree: recursively replace its branches with their means and return the mean value for the parent node.
"""
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    return (tree['left'] + tree['right']) / 2
def regPrune(tree,testData):
"""
    Recursive post-pruning. Requires a separate test set, ideally of a size comparable to the training set.
    Note: this pruning method suits continuous targets; collapsing to the mean is not appropriate for classification, where classes are discrete.
"""
if testData.shape[0] == 0:
return getMean(tree)
if (isTree(tree['left'])) or (isTree(tree['right'])):
lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
if isTree(tree['left']):
tree['left'] = regPrune(tree['left'],lSet)
if isTree(tree['right']):
tree['right'] = regPrune(tree['right'],rSet)
if (not isTree(tree['left'])) and (not isTree(tree['right'])):
lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        notMergeErr = np.sum(np.power(lSet[:,-1] - tree['left'], 2)) + np.sum(np.power(rSet[:,-1] - tree['right'], 2))
        treeMerge = (tree['left'] + tree['right']) / 2
        mergeErr = np.sum(np.power(testData[:,-1] - treeMerge, 2))
if mergeErr < notMergeErr:
print("Merging!")
return treeMerge
else:
return tree
else:
return tree
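# A minimal post-pruning sketch (hand-built tree and test data, hypothetical):
# both test targets are 1.5, so collapsing the two leaves (1.0 and 2.0) to their
# mean 1.5 reduces the test error and the subtree is merged into a single leaf.
if __name__ == '__main__':
    _tree = {'spInd': 0, 'spVal': 0.5, 'left': 1.0, 'right': 2.0}
    _test = np.mat([[0.0, 1.5],
                    [1.0, 1.5]])
    print(regPrune(_tree, _test))   # prints "Merging!" and returns 1.5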
def outJudge(dataSet,tree):
"""
    Classify/predict each out-of-sample row by walking it down the tree.
"""
outputData = pd.DataFrame(dataSet)
classResults = []
for ii in range(dataSet.shape[0]):
result = judgeType(dataSet[ii,:].A[0],tree)
classResults.append(result)
outputData['classResults'] = classResults
    outputData.to_excel('./data/machine_learning/mytree.xlsx', index=False)
    print("Out-of-sample classification (prediction) finished!")
def judgeType(data,tree):
"""
    Recursively determine the predicted class/value of a single sample.
"""
spInd = tree['spInd']
spVal = tree['spVal']
if data[spInd] <= spVal:
if isTree(tree['left']):
return judgeType(data,tree['left'])
return tree['left']
else:
if isTree(tree['right']):
return judgeType(data,tree['right'])
return tree['right']
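# A small usage sketch of judgeType (hand-built tree, hypothetical): samples with
# feature 0 <= 3 fall into the left leaf, the rest into the right leaf.
if __name__ == '__main__':
    _tree = {'spInd': 0, 'spVal': 3.0, 'left': 0.0, 'right': 10.0}
    print(judgeType(np.array([2.0]), _tree))   # 0.0  (2.0 <= 3.0 -> left leaf)
    print(judgeType(np.array([5.0]), _tree))   # 10.0 (5.0 >  3.0 -> right leaf)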
def treeCart(trainDataPath,outDataPath='',testDataPath='',leafType=regLeaf,errType=regErr,ops=(1,4),prune=False,fileType='txt'):
"""
    Main entry point. Parameters:
    trainDataPath: path to the training data
    outDataPath: path to the out-of-sample data to classify/predict
    testDataPath: path to the test data; required when post-pruning is requested
    leafType: how leaf nodes are created
    errType: how the error (impurity) is measured
    ops: pre-pruning parameters; the first element is the minimum error reduction required to split,
         the second is the minimum number of samples a branch must contain for the split to be kept
    prune: whether to apply post-pruning
    fileType: format of the training/test data (xlsx, txt, csv); txt files must be tab-separated
"""
dataMat = loadData(trainDataPath,fileType)
try:
myTree = creatTree(dataMat,leafType,errType,ops)
        if prune:
            testData = loadData(testDataPath, fileType)
            myTree = regPrune(myTree, testData)
        print('Decision tree built!')
        print(myTree)
if outDataPath != '':
outData = loadData(outDataPath,fileType)
outJudge(outData,myTree)
    except Exception as e:
        print("Check that the parameters are correct! ({})".format(e))
        print('Pass a leaf-creation method to treeCart:\n\t1. mean of the branch: leafType=regLeaf\n\t2. majority class: leafType=maxLeaf')
        print('Pass an error (impurity) measure to treeCart:\n\t1. Shannon entropy: errType=entErr\n\t2. Gini impurity: errType=giniErr\n\t3. squared error: errType=regErr')
        print('Pass the pre-pruning parameters ops to treeCart:\n\tthe first element is the minimum error reduction required to split; the second is the minimum number of samples a branch must contain for the split to be considered')
        print('Example: treeCart(trainDataPath, leafType=regLeaf, errType=regErr, ops=(1,4))')
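# A minimal classification sketch using the majority-class leaves and Gini
# impurity directly (in-memory toy data, hypothetical), bypassing the file
# loading done by treeCart. The two classes split cleanly at feature value 2.
if __name__ == '__main__':
    _data = np.mat([[1.0, 0.0],
                    [2.0, 0.0],
                    [3.0, 1.0],
                    [4.0, 1.0]])
    _clfTree = creatTree(_data, maxLeaf, giniErr, (0.1, 1))
    print(_clfTree)                              # {'spInd': 0, 'spVal': 2.0, 'left': 0.0, 'right': 1.0}
    print(judgeType(np.array([1.5]), _clfTree))  # 0.0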