决策树|matplotlib画决策树

介绍一下matplotlib画文本提示annotate的使用方法

# Template call (not runnable on its own) -- substitute real values first:
#   parent_pt / center_pt : (x, y) tuples, here in 'axes fraction' coordinates
#   node_type             : a bbox dict, e.g. dict(boxstyle="round4", fc="0.8")
create_plot.ax.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                        xytext=center_pt, textcoords='axes fraction',
                        va="center", ha="center",
                        arrowprops=arrow_args, bbox=node_type)

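xy = (横坐标, 纵坐标):被注释的点,箭头的一端连在这里(默认箭头尖端就在此点;本文使用的 arrowstyle="<-" 会把尖端画到文字一端)
xytext = (横坐标, 纵坐标):文字的坐标。文字相对这个点如何对齐由 ha/va 决定:默认 ha='left' 时它是文字的最左端;本文代码设置了 ha="center", va="center",所以它是文字框的中心
xycoords 和 textcoords:分别指定 xy 与 xytext 所用的坐标系;'axes fraction' 表示相对坐标轴区域的比例坐标,(0, 0) 为左下角,(1, 1) 为右上角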
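arrowprops:描述箭头外观的字典,例如
    facecolor = '颜色':箭头的填充颜色
    shrink = 小于 1 的数字:箭头两端各按箭头总长度的这一比例向内收缩
注意:shrink 只能在 arrowprops 中不含 arrowstyle 键时使用;本文代码使用 arrowstyle="<-",因此没有传 shrink。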
bbox:文字外框的属性字典,其中 boxstyle 是边框的形状(类型),可选值包括 'square'、'sawtooth'、'roundtooth'、'rarrow'、'round4'、'round' 等
fc:外框内部的填充颜色(facecolor);写成 0-1 之间的数字字符串时表示灰度,"1.0" 为白色,"0.0" 为黑色

画决策树

import matplotlib.pyplot as plt

# Box style for internal (decision) nodes: sawtooth border, light-gray fill.
decision_node = dict(boxstyle="sawtooth", fc="0.8")
# Box style for leaf nodes: round4 border, light-gray fill.
leaf_node = dict(boxstyle="round4", fc="0.8")
# Edge style: a line with an arrowhead. "<-" puts the head at the text end of
# an annotation (xytext, the child node), so edges point parent -> child.
arrow_args = dict(arrowstyle="<-")


def plot_node(node_text, centerPt, parentPt, nodeType):
    """
    Draw one tree node as a boxed annotation connected to its parent.

    :param node_text: text shown inside the box
    :param centerPt: (x, y) of the box center, in axes-fraction coordinates
    :param parentPt: (x, y) at the other end of the edge, in axes-fraction coordinates
    :param nodeType: bbox style dict for the box border/fill
    :return: None
    """
    # Both points share the axes-fraction system; the text is centered on centerPt.
    annotate_opts = dict(
        xycoords='axes fraction',
        textcoords='axes fraction',
        va="center",
        ha="center",
        arrowprops=arrow_args,
        bbox=nodeType,
    )
    create_plot.ax.annotate(node_text, xy=parentPt, xytext=centerPt, **annotate_opts)


def get_num_leafs(mytree):
    """
    Count the leaf nodes of a nested-dict decision tree.

    A node is a dict whose first key is the node label and whose value maps
    branch labels to either another node dict (a subtree) or a leaf value.

    :param mytree: the (sub)tree to count
    :return: number of non-dict values reachable from the root node
    """
    # Only the first key is inspected; it names the node, its value holds the branches.
    first_str = list(mytree.keys())[0]
    second_dict = mytree[first_str]
    # isinstance (not a type-name comparison) so dict subclasses such as
    # OrderedDict/defaultdict are also recursed into as subtrees.
    return sum(
        get_num_leafs(child) if isinstance(child, dict) else 1
        for child in second_dict.values()
    )


def get_tree_depth(mytree):
    """
    Return the depth of a nested-dict decision tree.

    A node is a dict whose first key is the node label and whose value maps
    branch labels to either another node dict (a subtree) or a leaf value.

    :param mytree: the (sub)tree to measure
    :return: number of branch edges on the longest root-to-leaf path
             (1 when every branch of the root is a leaf; 0 if the root has
             no branches)
    """
    first_str = list(mytree.keys())[0]
    second_dict = mytree[first_str]
    # Each branch adds one level; subtrees (including dict subclasses) add their own depth.
    return max(
        (1 + get_tree_depth(child) if isinstance(child, dict) else 1
         for child in second_dict.values()),
        default=0,
    )


def plot_mid_text(cntrPt, parentPt, txt_sting):
    """
    Write a branch label halfway along the edge between two nodes.

    :param cntrPt: (x, y) of the child node, in axes-fraction coordinates
    :param parentPt: (x, y) of the parent node, in axes-fraction coordinates
    :param txt_sting: label text to draw
    """
    # Midpoint of the segment: cntr + (parent - cntr) / 2. Adding the two
    # coordinates instead of subtracting put the label near
    # 1.5*cntr + 0.5*parent, often outside the axes.
    x_mid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    y_mid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    # ax.text defaults to data coordinates; pin it to axes-fraction
    # coordinates so labels line up with the 'axes fraction' node annotations.
    create_plot.ax.text(x_mid, y_mid, txt_sting,
                        transform=create_plot.ax.transAxes)


def plot_tree(myTree, parentPt, nodeTxt):
    """
    Recursively draw a (sub)tree below ``parentPt``.

    Layout state is kept in function attributes that ``create_plot`` sets
    before the first call:
    - ``plot_tree.totalW``: total leaf count.
    - ``plot_tree.totalD``: tree depth.
    - ``plot_tree.x0ff`` / ``plot_tree.y0ff``: the drawing cursor.

    Each leaf advances the cursor one slot (1/totalW) to the right. Each
    level moves it one row (1/totalD) down, and it is restored on the way
    back up.

    :param myTree: nested-dict (sub)tree to draw
    :param parentPt: (x, y) of the parent node this subtree's root connects to
    :param nodeTxt: label for the edge from the parent to this subtree's root
    """
    num_leafs = get_num_leafs(myTree)
    first_str = list(myTree.keys())[0]
    # Center this subtree's root over the num_leafs slots to the right of the cursor.
    cntrPt = (plot_tree.x0ff + (1.0 + float(num_leafs)) / 2.0 / plot_tree.totalW,
              plot_tree.y0ff)
    plot_mid_text(cntrPt, parentPt, nodeTxt)
    plot_node(first_str, cntrPt, parentPt, decision_node)
    second_dict = myTree[first_str]

    # Children are drawn one row lower.
    plot_tree.y0ff = plot_tree.y0ff - 1.0 / plot_tree.totalD
    for key, child in second_dict.items():
        if isinstance(child, dict):
            plot_tree(child, cntrPt, str(key))
        else:
            # Leaf: take the next slot to the right and draw it there.
            plot_tree.x0ff = plot_tree.x0ff + 1.0 / plot_tree.totalW
            leaf_pt = (plot_tree.x0ff, plot_tree.y0ff)
            plot_node(child, leaf_pt, cntrPt, leaf_node)
            plot_mid_text(leaf_pt, cntrPt, str(key))
    # Restore this subtree's row before returning to the caller.
    plot_tree.y0ff = plot_tree.y0ff + 1.0 / plot_tree.totalD


def create_plot(inTree, *, save_path='output.jpg', show=True):
    """
    Draw a nested-dict decision tree with matplotlib.

    :param inTree: nested-dict decision tree to draw
    :param save_path: file the figure is saved to ('output.jpg' by default);
                      pass None to skip saving
    :param show: whether to call plt.show(); set False for headless/batch use
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # No ticks and no frame: only the tree itself is drawn.
    axprops = dict(xticks=[], yticks=[])
    create_plot.ax = plt.subplot(111, frameon=False, **axprops)
    # Layout state read by plot_tree: one column slot per leaf, one row per level.
    plot_tree.totalW = float(get_num_leafs(inTree))
    plot_tree.totalD = float(get_tree_depth(inTree))
    # Start half a slot left of 0 so the first leaf lands at 0.5/totalW and
    # the root is centered at x = 0.5.
    plot_tree.x0ff = -0.5 / plot_tree.totalW
    plot_tree.y0ff = 1.0
    # The root is placed at (0.5, 1.0); passing that same point as its
    # "parent" gives it a zero-length edge and an empty label.
    plot_tree(inTree, (0.5, 1.0), '')
    if save_path is not None:
        plt.savefig(save_path)
    if show:
        plt.show()


if __name__ == '__main__':
    # Demo tree: the root splits three ways; branch 1 leads to a second decision node.
    sample_tree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: 'no', 1: 'yes'}},
            2: 'maybe',
        },
    }
    create_plot(sample_tree)

结果展示

决策树|matplotlib画决策树_第1张图片

你可能感兴趣的:(机器学习)