介绍 matplotlib 中绘制带箭头文本注释的 annotate 方法的用法
# Template of the annotate call that plot_node makes below; every ' ' is a placeholder.
# xy / xytext take (x, y) tuples, and bbox takes a box-style dict such as decision_node.
create_plot.ax.annotate(node_text,xy=' ', xycoords='axes fraction',
xytext=' ', textcoords='axes fraction',
va="center", ha="center",
arrowprops=arrow_args, bbox=' ')
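xy = (横坐标, 纵坐标):被注释点的位置。箭头连线从 xytext 画到 xy。arrowstyle 为 "->" 时箭头尖端落在 xy;为 "<-" 时箭头画在文字(xytext)一端
xytext = (横坐标, 纵坐标):文字的放置坐标。文字相对该点如何对齐由 ha/va 决定(上例 ha="center", va="center",即文字中心落在该点),并不一定是文字最左边的坐标
xycoords 和 textcoords:分别指定 xy 与 xytext 所用的坐标系。'axes fraction' 表示按坐标轴区域的比例计,(0, 0) 为左下角,(1, 1) 为右上角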
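arrowprops = dict(
    facecolor='颜色',  # 箭头的填充颜色
    shrink=0.05,       # 0~1 之间的小数(< 1):收缩箭头,两端各去掉总长度的这一比例,使箭头不紧贴文字和被注释点
)
注意:shrink(以及 width、headwidth、headlength)只能在 arrowprops 不含 'arrowstyle' 键时使用。若像下文的 arrow_args 那样指定了 arrowstyle,应改用 shrinkA/shrinkB(单位为点)控制两端收缩。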
bbox:文字外框的样式字典,常用键如下
boxstyle:边框的形状(类型),可选值有 'square'、'round'、'round4'、'sawtooth'、'roundtooth'、'circle'、'larrow'、'rarrow'、'darrow' 等
fc:边框的填充颜色(facecolor 的简写)。取 "0"~"1" 之间的数字字符串时表示灰度,"1.0" 为白色,"0.0" 为黑色
import matplotlib.pyplot as plt
# Box style for internal (decision) nodes: saw-toothed border, light-gray fill.
decision_node = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style for leaf nodes: rounded border, same light-gray fill.
leaf_node = {"boxstyle": "round4", "fc": "0.8"}
设置连线样式:带箭头的线。arrowstyle="<-" 表示箭头画在 xytext 一端,在 plot_node 中即子节点一端,使箭头从父节点指向子节点
# Connector style for node annotations; "<-" draws the head at the xytext end.
arrow_args = {"arrowstyle": "<-"}
def plot_node(node_text, centerPt, parentPt, nodeType):
    """Draw one tree node as a boxed text annotation linked to its parent.

    :param node_text: text shown inside the node box
    :param centerPt: (x, y) of this node, in axes-fraction coordinates
    :param parentPt: (x, y) of the parent node the connector is drawn to,
        in axes-fraction coordinates
    :param nodeType: bbox style dict for the box (e.g. decision_node or leaf_node)
    :return: None
    """
    axes = create_plot.ax
    axes.annotate(
        node_text,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        arrowprops=arrow_args,
        bbox=nodeType,
    )
def get_num_leafs(mytree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree looks like ``{feature: {branch_value: subtree_or_label, ...}}``.
    Only the first root key is inspected. Each branch value is either another
    tree (a dict, recursed into) or a leaf label (anything else, counted as 1).

    :param mytree: the tree to inspect
    :return: number of leaves; 0 for an empty tree
    """
    if not mytree:
        return 0
    root = next(iter(mytree))
    branches = mytree[root]
    return sum(
        get_num_leafs(child) if isinstance(child, dict) else 1
        for child in branches.values()
    )
def get_tree_depth(mytree):
    """Return the number of decision levels in a nested-dict decision tree.

    A root whose branches are all leaf labels has depth 1, and each nested
    subtree (a dict value) adds one level. Only the first root key is inspected.

    :param mytree: tree shaped ``{feature: {branch_value: subtree_or_label}}``
    :return: depth of the tree; 0 for an empty tree or a root with no branches
    """
    if not mytree:
        return 0
    root = next(iter(mytree))
    branches = mytree[root]
    return max(
        (1 + get_tree_depth(child) if isinstance(child, dict) else 1
         for child in branches.values()),
        default=0,
    )
def plot_mid_text(cntrPt, parentPt, txt_sting):
    """Write a branch label halfway along the edge between two nodes.

    :param cntrPt: (x, y) of the child node, in axes-fraction coordinates
    :param parentPt: (x, y) of the parent node, in axes-fraction coordinates
    :param txt_sting: label text (the branch value); the parameter name is
        kept unchanged for callers passing it by keyword
    :return: None
    """
    ax = create_plot.ax
    # Plain midpoint of the segment. The previous form (p + c) / 2 + c added
    # the child coordinate a second time and pushed labels off the axes.
    x_mid = (parentPt[0] + cntrPt[0]) / 2.0
    y_mid = (parentPt[1] + cntrPt[1]) / 2.0
    # Use axes-fraction coordinates, the same system the node annotations use,
    # instead of data coordinates (which only coincide while limits stay 0..1).
    ax.text(x_mid, y_mid, txt_sting, transform=ax.transAxes)
def plot_tree(myTree, parentPt, nodeTxt):
    """Recursively draw a (sub)tree whose root hangs below ``parentPt``.

    Layout state lives in function attributes initialised by ``create_plot``:
    ``plot_tree.totalW`` / ``plot_tree.totalD`` hold the leaf count and depth of
    the whole tree. ``plot_tree.x0ff`` is the x of the most recently placed leaf,
    and ``plot_tree.y0ff`` is the current level's y. All of these are in
    axes-fraction coordinates.

    :param myTree: nested-dict tree ``{feature: {branch_value: subtree_or_label}}``
    :param parentPt: (x, y) of the parent node
    :param nodeTxt: label of the branch leading to this subtree ('' for the root)
    :return: None
    """
    num_leafs = get_num_leafs(myTree)
    first_str = next(iter(myTree))
    # Centre this decision node over the horizontal span its leaves will occupy.
    cntrPt = (plot_tree.x0ff + (1.0 + float(num_leafs)) / 2.0 / plot_tree.totalW,
              plot_tree.y0ff)
    plot_mid_text(cntrPt, parentPt, nodeTxt)
    plot_node(first_str, cntrPt, parentPt, decision_node)
    second_dict = myTree[first_str]
    # Children are drawn one level lower.
    plot_tree.y0ff = plot_tree.y0ff - 1.0 / plot_tree.totalD
    for key, child in second_dict.items():
        if isinstance(child, dict):
            plot_tree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next 1/totalW-wide horizontal slot.
            plot_tree.x0ff = plot_tree.x0ff + 1.0 / plot_tree.totalW
            leaf_pt = (plot_tree.x0ff, plot_tree.y0ff)
            plot_node(child, leaf_pt, cntrPt, leaf_node)
            plot_mid_text(leaf_pt, cntrPt, str(key))
    # Restore this level's y before returning to the caller.
    plot_tree.y0ff = plot_tree.y0ff + 1.0 / plot_tree.totalD
def create_plot(inTree, save_path='output.jpg'):
    """Draw ``inTree`` in a fresh figure, optionally save it, and show it.

    :param inTree: nested-dict tree ``{feature: {branch_value: subtree_or_label}}``
    :param save_path: file the figure is written to. The default is
        ``'output.jpg'``, the previously hard-coded name; pass ``None`` to
        skip saving.
    :return: None
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # Stored on the function so plot_node / plot_mid_text can reach the axes.
    create_plot.ax = plt.subplot(111, frameon=False, **axprops)
    # Layout state shared with plot_tree through function attributes.
    plot_tree.totalW = float(get_num_leafs(inTree))
    plot_tree.totalD = float(get_tree_depth(inTree))
    # Start half a slot left of 0 so the first leaf lands at 0.5 / totalW.
    plot_tree.x0ff = -0.5 / plot_tree.totalW
    plot_tree.y0ff = 1.0
    plot_tree(inTree, (0.5, 1.0), '')
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
if __name__ == '__main__':
    # Small sample tree: one nested subtree and three direct leaf labels.
    sample_tree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: 'no', 1: 'yes'}},
            2: 'maybe',
        },
    }
    create_plot(sample_tree)