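The previous section showed how to build a tree from a dataset, but the dictionary representation is hard to read, and drawing it as a figure directly is not straightforward either. In this section we use the Matplotlib library to draw tree diagrams.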
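3.2.1 Matplotlib annotations
Matplotlib provides an annotation tool, the annotate() method, which adds text annotations (optionally with arrows) to a plot.
Create a file named treePlotter.py and enter: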
# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt

# Text-box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # fc = facecolor; "0.8" is a gray level (0 = black, 1 = white)
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # centerPt: where the node's text box is drawn (the arrowhead end)
    # parentPt: the parent's position (the arrow's tail)
    # createPlot.ax1 is a function attribute used as a global handle to the axes
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
                            xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # u'' marks a Unicode literal (required in Python 2, harmless in Python 3)
    plotNode(u'決策節點', (0.5, 0.1), (0.1, 0.5), decisionNode)  # "decision node"
    plotNode(u'葉節點', (0.8, 0.1), (0.3, 0.8), leafNode)  # "leaf node"
    plt.show()
Create a script named run_treePlotter.py to run it:
# run_treePlotter.py
import treePlotter
print('>>> treePlotter.createPlot()')
treePlotter.createPlot()
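The result:
The figure looks quite good, and this is how tree nodes are drawn. One problem is that the Chinese labels come out garbled, and I did not know how to fix it.
So I switched the labels to English, changing the two plotNode calls in createPlot() to: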
    plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
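If you would rather keep the Chinese labels, a common workaround (not from the original post) is to point Matplotlib at a font that has CJK glyphs before plotting; the garbled output usually appears because the default font has no glyphs for these characters. This is a sketch under one assumption: a font named SimHei is installed on your system. Substitute any CJK font you have; if the named font is missing, Matplotlib only warns and falls back to the next font in the list.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop this line to use plt.show()
import matplotlib.pyplot as plt

# Assumption: "SimHei" is installed. Put it first so it is tried before the default fonts.
plt.rcParams["font.sans-serif"] = ["SimHei"] + plt.rcParams["font.sans-serif"]
# Commonly set alongside CJK fonts: draw minus signs with an ASCII hyphen,
# since some CJK fonts lack the Unicode minus glyph.
plt.rcParams["axes.unicode_minus"] = False

fig, ax = plt.subplots()
ax.annotate("決策節點", xy=(0.1, 0.5), xycoords="axes fraction",
            xytext=(0.5, 0.1), textcoords="axes fraction",
            va="center", ha="center",
            bbox=dict(boxstyle="sawtooth", fc="0.8"),
            arrowprops=dict(arrowstyle="<-"))
fig.canvas.draw()  # render once; with an interactive backend, call plt.show() instead
```

The same two rcParams lines can go at the top of treePlotter.py, after importing pyplot, so that createPlot() picks them up.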
=========================================================================
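3.2.2 Constructing the annotation tree
We have x and y coordinates, but deciding where to place every tree node is still a problem. We need to know how many leaf nodes there are so that we can size the x axis correctly, and how many levels the tree has so that we can size the y axis correctly.
Here we define two new functions that return the number of leaf nodes and the depth (number of levels) of the tree.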
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))  # the root key; in Python 3, dict.keys() cannot be indexed
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # A dict value is a subtree; any other value is a leaf (class label)
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):  # number of decision nodes on the longest root-to-leaf path
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
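To check the two functions by hand, here is a compact, self-contained Python 3 sketch run on a small tree. The tree is a hypothetical example in the same nested-dict shape used in this section ({feature: {value: subtree_or_label}}), not data from the original post.

```python
def getNumLeafs(myTree):
    """Count leaves: every non-dict value below the root is one leaf."""
    firstStr = next(iter(myTree))  # root feature name
    numLeafs = 0
    for child in myTree[firstStr].values():
        numLeafs += getNumLeafs(child) if isinstance(child, dict) else 1
    return numLeafs

def getTreeDepth(myTree):
    """Number of decision nodes on the longest root-to-leaf path."""
    firstStr = next(iter(myTree))
    maxDepth = 0
    for child in myTree[firstStr].values():
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth

# Hypothetical sample tree: two decision nodes, three leaves.
sample_tree = {'no surfacing': {0: 'no',
                                1: {'flippers': {0: 'no', 1: 'yes'}}}}

print(getNumLeafs(sample_tree))   # 3
print(getTreeDepth(sample_tree))  # 2
```

The width (3) becomes the number of horizontal leaf slots and the depth (2) the number of vertical steps when the tree is drawn.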
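Then add the following to run_treePlotter.py. It calls retrieveTree(i), a helper in treePlotter.py that returns a sample tree by index; its definition is not shown in this section.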
# run_treePlotter.py
import importlib
import treePlotter
print('>>> treePlotter.createPlot()')
treePlotter.createPlot()
print('***************************************\n')
importlib.reload(treePlotter)  # picks up edits to treePlotter.py in an interactive session
print('>>> treePlotter.retrieveTree(1)')
print(treePlotter.retrieveTree(1))
print('>>> myTree = treePlotter.retrieveTree(0)')
print('>>> treePlotter.getNumLeafs(myTree)')
myTree = treePlotter.retrieveTree(0)
print(treePlotter.getNumLeafs(myTree))
print('>>> treePlotter.getTreeDepth(myTree)')
print(treePlotter.getTreeDepth(myTree))
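The script above relies on retrieveTree(i), which this section never defines. If you are following along without it, here is a hypothetical stand-in to put in treePlotter.py. The two trees are made-up fixtures in the same nested-dict format, chosen so that tree 0 has the 'no surfacing' root that the last script in this section indexes into; replace them with trees built in the previous section if you have them.

```python
def retrieveTree(i):
    """Hypothetical helper: return one of a few hard-coded sample trees.

    Each tree is a nested dict {feature: {feature_value: subtree_or_class_label}}.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                                    1: 'no'}}}},
    ]
    return listOfTrees[i]
```

Because the list is rebuilt on every call, mutating a returned tree (as the last script in this section does when it adds a branch) does not change what later calls return.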
The output:
Of course, we still have not drawn the complete tree.
Now add the plotting code:
def plotMidText(cntrPt, parentPt, txtString):  # write text halfway between parent and child
    # midpoint between the parent and child positions
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaves
    firstStr = next(iter(myTree))
    # Center this node horizontally above the leaves of its own subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge with the feature value
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  # advance to the next leaf slot
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # move back up
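Written out, the two lines in plotMidText compute the midpoint of the edge from the child at $(x_c, y_c)$ = cntrPt to the parent at $(x_p, y_p)$ = parentPt:

```latex
x_{\text{mid}} = x_c + \frac{x_p - x_c}{2} = \frac{x_p + x_c}{2},
\qquad
y_{\text{mid}} = y_c + \frac{y_p - y_c}{2} = \frac{y_p + y_c}{2}
```

For example, with a parent at (0.5, 1.0) and a child at (1/6, 0.5), the edge label is placed at (1/3, 0.75).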
Then change the earlier createPlot() function into createPlot(inTree):
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # tree width, in leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # tree depth, in levels
    # Start half a leaf slot left of 0, so each leaf lands at the center of its slot
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
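plotTree keeps its state in function attributes (plotTree.xOff, plotTree.yOff, totalW, totalD), which makes the layout hard to follow while it is mixed with drawing. The sketch below is a drawing-free re-implementation of the same arithmetic; the name layout_tree and the sample tree are hypothetical, not part of the original code. With W leaves, xOff starts at -0.5/W and each leaf advances it by 1/W, so the k-th leaf (counting from 0) lands at (k + 0.5)/W, the center of one of W equal slots. Each decision node is centered over the leaves of its own subtree, and each level drops by 1/D.

```python
def getNumLeafs(tree):
    root = next(iter(tree))
    return sum(getNumLeafs(c) if isinstance(c, dict) else 1 for c in tree[root].values())

def getTreeDepth(tree):
    root = next(iter(tree))
    return max((1 + getTreeDepth(c) if isinstance(c, dict) else 1
                for c in tree[root].values()), default=0)

def layout_tree(tree):
    """Return (label, x, y, kind) tuples using plotTree's arithmetic, without drawing."""
    totalW = float(getNumLeafs(tree))
    totalD = float(getTreeDepth(tree))
    state = {'xOff': -0.5 / totalW, 'yOff': 1.0}  # same start values as createPlot
    nodes = []

    def place(subtree):
        numLeafs = getNumLeafs(subtree)
        root = next(iter(subtree))
        # Center this decision node over the leaves of its own subtree.
        x = state['xOff'] + (1.0 + numLeafs) / 2.0 / totalW
        nodes.append((root, x, state['yOff'], 'decision'))
        state['yOff'] -= 1.0 / totalD  # children sit one level lower
        for child in subtree[root].values():
            if isinstance(child, dict):
                place(child)
            else:
                state['xOff'] += 1.0 / totalW  # advance to the next leaf slot
                nodes.append((child, state['xOff'], state['yOff'], 'leaf'))
        state['yOff'] += 1.0 / totalD  # back up to this node's level

    place(tree)
    return nodes

# Hypothetical sample tree in the same shape as the trees in this section.
sample_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
for label, x, y, kind in layout_tree(sample_tree):
    print(f"{kind:8} {label!s:12} x={x:.3f} y={y:.3f}")
```

On the sample tree this places the root at (0.5, 1.0), the three leaves at x ≈ 0.167, 0.5 and 0.833, and the flippers node at x ≈ 0.667, centered over its two leaves. Note that the root lands exactly on the parentPt (0.5, 1.0) that createPlot passes in, so the root's own arrow has zero length and is effectively invisible.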
To run the code, change run_treePlotter.py to:
# run_treePlotter.py
import importlib
import treePlotter
# Earlier tests, disabled by wrapping them in a string literal:
"""
print('>>> treePlotter.createPlot()')
treePlotter.createPlot()
print('***************************************\n')
importlib.reload(treePlotter)
print('>>> treePlotter.retrieveTree(1)')
print(treePlotter.retrieveTree(1))
print('>>> myTree = treePlotter.retrieveTree(0)')
print('>>> treePlotter.getNumLeafs(myTree)')
myTree = treePlotter.retrieveTree(0)
print(treePlotter.getNumLeafs(myTree))
print('>>> treePlotter.getTreeDepth(myTree)')
print(treePlotter.getTreeDepth(myTree))
"""
print('***************************************\n')
importlib.reload(treePlotter)
print('>>> myTree = treePlotter.retrieveTree(0)')
myTree = treePlotter.retrieveTree(0)
print('>>> treePlotter.createPlot(myTree)')
treePlotter.createPlot(myTree)  # returns None, so there is nothing to print
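Running it produces this figure:
The figure has no axis ticks or tick labels, because createPlot now passes xticks=[] and yticks=[]. Next, we add a few commands to the run script that modify the tree and redraw it.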
# run_treePlotter.py
import treePlotter
print('>>> myTree = treePlotter.retrieveTree(0)')
myTree = treePlotter.retrieveTree(0)
print(">>> myTree['no surfacing'][3] = 'maybe'")  # add a branch
myTree['no surfacing'][3] = 'maybe'
print('>>> myTree')
print(myTree)
print('>>> treePlotter.createPlot(myTree)')
treePlotter.createPlot(myTree)
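As a check on the layout arithmetic, since createPlot starts xOff at $-0.5/W$ and plotTree adds $1/W$ per leaf, the k-th leaf (counting from 0) always sits at the center of one of W equal slots. Assuming retrieveTree(0) is a three-leaf, two-level tree (for example the hypothetical {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}), the new 'maybe' branch raises W from 3 to 4 while D stays 2:

```latex
x_k = -\frac{0.5}{W} + \frac{k + 1}{W} = \frac{k + \tfrac{1}{2}}{W}, \quad k = 0, \dots, W - 1
\qquad\Longrightarrow\qquad
W = 4:\; x \in \left\{\tfrac{1}{8}, \tfrac{3}{8}, \tfrac{5}{8}, \tfrac{7}{8}\right\}
```

Each leaf slot narrows from 1/3 to 1/4 of the width, the root stays centered at x = 1/2, and the vertical spacing stays 1/D = 1/2. In Python 3.7+ dicts keep insertion order, so the new branch (key 3, inserted last) is drawn in the rightmost slot, at x = 7/8.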
Result: