# Source: 天天看點 (blog scrape)
# Python implementation of decision trees: ID3 / C4.5 with pessimistic pruning

#!/usr/bin/python
# -*- coding: utf-8 -*-
import math
import random
import operator

'''
zys 2016-10-01
Implements C4.5 (and contains an ID3 variant).
Reference: http://blog.csdn.net/lulei1217/article/details/49583357
Procedure:
1. Build the decision tree
    1) find the first best split feature by information gain (ratio)
    2) for each value of that feature (e.g. 0/1), drop the rows and the
       column already accounted for by the split
    3) find the next best feature on the reduced data
    4) repeat from step 2 until a stop condition holds:
        - all remaining samples share the same class, or
        - every feature has already been used
2. Apply pessimistic pruning
3. Classify test items with the (pruned) tree
'''


def createTree(traindata, labels):
    '''Recursively build a C4.5 decision tree.

    traindata: list of rows; each row is the feature values followed by
    the class label as the last element.
    labels: feature names parallel to the feature columns.  This list is
    mutated (the chosen feature is deleted), so callers should pass a copy.
    Returns a class label (leaf) or a nested dict
    {feature_name: {feature_value: subtree_or_label}}.
    '''
    classList = [item[-1] for item in traindata]
    # Stop when every remaining sample has the same class and return it
    # directly (e.g. when badrecord == 1, every offer is 0).
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop when only the class column is left (all features consumed)
    # and return the majority class.
    if len(traindata[0]) == 1:
        return majorityCnt(classList)
    bestFeature = chooseBestFeature(traindata)  # best feature by gain ratio
    bestFeatLabel = labels[bestFeature]
    tree = {bestFeatLabel: {}}
    featureList = [example[bestFeature] for example in traindata]
    uniqueVals = set(featureList)
    '''
    Recurse on each value of the chosen feature; the feature is removed
    from the label list here and from the data rows inside
    getLablesByfeature.
    '''
    del(labels[bestFeature])
    for feature in uniqueVals:
        A = getLablesByfeature(traindata, bestFeature, feature)
        tree[bestFeatLabel][feature] = createTree(A, labels[:])
    return tree


def majorityCnt(classList):
    '''Return the most frequent class label in classList.

    Used when the last feature has been consumed but the remaining
    samples still disagree on the class.
    '''
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # key=operator.itemgetter(1): sort the (label, count) pairs by count,
    # descending, and keep the label of the first pair.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def chooseBestFeature(traindata):
    '''Return the index of the feature with the highest C4.5 gain ratio.

    ID3 uses the information gain: Gain(S,A) = Entropy(S) - sum(|A|/|S|) * Entropy(A)
    C4.5 uses the gain ratio:
    SplitInfo(A) = - sum(|A|/|S|) * log2(|A|/|S|)
    GainRate(S,A) = Gain(S,A) / SplitInfo(A)
    Entropy(S) is the entropy of the whole set S,
    A is the subset of S for one value of a feature,
    |S| is the number of samples in S.
    '''
    size = len(traindata[0]) - 1   # number of feature columns (last one is the label)
    Entropy = calculateEntropy(traindata)  # Entropy(S) of the whole set
    GainRate = 0.0
    bestFeature = -1
    for i in range(size):
        # All values of feature i, e.g. school -> [0,0,0,0,1,1,1,1]
        featureList = [example[i] for example in traindata]
        # Distinct values of feature i, e.g. school -> {0, 1}
        uniqueVals = set(featureList)
        newEntropy = 0.0
        splitInfo = 0.0
        for feature in uniqueVals:
            A = getLablesByfeature(traindata, i, feature)
            prob = float(len(A)) / len(traindata)
            # Weighted entropy of the subset, Entropy(A)
            newEntropy += prob * calculateEntropy(A)
            info = 0.0
            if(prob != 0):
                info = math.log(prob, 2)
            splitInfo -= prob * info
        newGain = Entropy - newEntropy  # information gain
        if (splitInfo == 0):
            # A single-valued feature gives SplitInfo == 0 and the ratio
            # would divide by zero; substitute the SplitInfo of an almost
            # pure 0.99/0.01 split.  (Bug fix: the original assigned to a
            # misspelled `splitinfo`, so the guard had no effect.)
            splitInfo = -0.99 * math.log(0.99, 2) - 0.01 * math.log(0.01, 2)
        newGain = newGain / splitInfo
        if (newGain > GainRate):
            GainRate = newGain
            bestFeature = i
    return bestFeature

'''
ID3 中的代碼:
def chooseBestFeature(traindata):
    size = len(traindata[0]) - 1
    Entropy = calculateEntropy(traindata)
    Gain = 0.0
    bestFeature = -1
    for i in range(size):
        uniqueVals = set(featureList)
        newEntropy = 0.0
        for feature in uniqueVals:
            A = getLablesByfeature(traindata, i, feature)
            prob = float(len(A)) / len(traindata)
            newEntropy += prob * calculateEntropy(A)
        newGain = Entropy - newEntropy
        if (newGain > Gain):
            Gain = newGain
            bestFeature = i
    return bestFeature
'''


def getLablesByfeature(traindata, index, feature):
    '''Return the rows whose column `index` equals `feature`, with that
    column removed from each returned row.

    Example: for school == 0 this yields the remaining columns of the
    matching rows (including their labels, e.g. [0,0,1,0]).
    '''
    A = []
    for item in traindata:
        if item[index] == feature:
            temp = item[:index]  # columns before the matched feature
            temp.extend(item[index + 1:])  # columns after it
            A.append(temp)
    return A


def calculateEntropy(data):
    '''Shannon entropy of the class labels (last column) of data.

    Entropy = -sum(P(ui) * log2(P(ui))) where P(ui) is the relative
    frequency of class ui.
    '''
    labelCount = {}
    for item in data:
        lable = item[-1]
        labelCount[lable] = labelCount.get(lable, 0) + 1
    entropy = 0.0
    for key in labelCount:
        p = float(labelCount[key]) / len(data)
        entropy -= p * math.log(p, 2)
    return entropy


def classify(tree, lables, item):
    '''Classify one sample `item` with a built decision tree.

    tree: nested dict as produced by createTree.
    lables: feature names, parallel to item's columns.
    tree.keys() must be wrapped in list() to index it in Python 3.
    '''
    root = list(tree.keys())[0]
    nextNode = tree[root]
    index = lables.index(root)
    key = item[index]
    val = nextNode[key]
    '''
    If val is a dict the node has children: recurse.
    Otherwise it is a leaf and val is the predicted class.
    '''
    if(isinstance(val, dict)):
        classLabel = classify(val, lables, item)
    else:
        classLabel = val
    return classLabel


def getCount(tree, data, lables, count):
    '''Append [right, wrong] classification counts to `count` for every
    leaf of `tree`, evaluated against the samples in `data`.

    lables is mutated (the root feature is deleted), so callers pass a copy.
    '''
    root = list(tree.keys())[0]
    nextNode = tree[root]
    index = lables.index(root)
    del(lables[index])
    for key in nextNode.keys():
        rightcount = 0
        wrongcount = 0
        A = getLablesByfeature(data, index, key)
        # Internal node: recurse into the subtree with the matching rows.
        if(isinstance(nextNode[key], dict)):
            getCount(nextNode[key], A, lables[:], count)
        else:
            for item in A:
                # Compare the sample's true class with the leaf's class.
                if(str(item[-1]) == str(nextNode[key])):
                    rightcount += 1
                else:
                    wrongcount += 1
            count.append([rightcount, wrongcount])


def cutBranch(tree, data, lables):
    '''Pessimistic error pruning (PEP). Reference: http://www.jianshu.com/p/794d08199e5e

    old = errorNum + 0.5 * L   errorNum: misclassified samples over the
                               subtree's leaves, L: number of leaves
    p = old / N                N: samples reaching the subtree
    new = errorNum + 0.5       error if the subtree became a single leaf
    S = math.sqrt(N * p * (1 - p))   standard deviation of the error
    prune when new <= old - S

    (Author's note: this is the author's own reading of PEP; corrections
    welcome.)  Returns a new pruned tree; the input tree is not modified.
    '''
    root = list(tree.keys())[0]
    nextNode = tree[root]
    index = lables.index(root)
    newTree = {root: {}}
    del(lables[index])
    for key in nextNode.keys():
        # Only non-leaf children are candidates for pruning.
        if(isinstance(nextNode[key], dict)):
            A = getLablesByfeature(data, index, key)
            count = []
            # Collect (right, wrong) counts for every leaf of the subtree.
            getCount(nextNode[key], A, lables[:], count)
            allnum = 0
            errornum = 0
            for i in count:
                allnum += i[0] + i[1]
                errornum += i[1]
            if(errornum == 0):
                # The subtree classifies everything correctly: keep it
                # and move on to the next child.
                newTree[root][key] = nextNode[key]
                continue
            old = errornum + len(count) * 0.5
            new = errornum + 0.5
            p = old / allnum
            S = math.sqrt(allnum * p * (1 - p))
            if(new <= old - S):
                # Prune: replace the subtree with its majority class.
                classList = [item[-1] for item in A]
                newTree[root][key] = majorityCnt(classList)
            else:
                # Not prunable here: descend and try to prune deeper.
                newTree[root][key] = cutBranch(nextNode[key], A, lables[:])
        else:
            newTree[root][key] = nextNode[key]
    return newTree

if(__name__ == "__main__"):
    '''
    -----------------------------Example-----------------------------
    Campus-recruiting toy data set:
    elite school  ability    bad record  offer
    yes(1)        strong(1)  yes(1)      hired(1)
    no(0)         normal(0)  no(0)       not hired(0)

    data = [[0, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 1],
            [0, 1, 1, 0],
            [1, 0, 0, 1],
            [1, 0, 1, 0],
            [1, 1, 0, 1],
            [1, 1, 1, 0]]
    lables = ["school", "ability", "badrecord", "offer"]
    '''
    # data = [['youth', 'high', 'no', 'fair', 'no'],
    #         ['youth', 'high', 'no', 'excellent', 'no'],
    #         ['middle_aged', 'high', 'no', 'fair', 'yes'],
    #         ['senior', 'medium', 'no', 'fair', 'yes'],
    #         ['senior', 'low', 'yes', 'fair', 'yes'],
    #         ['senior', 'low', 'yes', 'excellent', 'no'],
    #         ['middle_aged', 'low', 'yes', 'excellent', 'yes'],
    #         ['youth', 'medium', 'no', 'fair', 'no'],
    #         ['youth', 'low', 'yes', 'fair', 'yes'],
    #         ['senior', 'medium', 'yes', 'fair', 'yes'],
    #         ['youth', 'medium', 'yes', 'excellent', 'yes'],
    #         ['middle_aged', 'medium', 'no', 'excellent', 'yes'],
    #         ['middle_aged', 'high', 'yes', 'fair', 'yes'],
    #         ['senior', 'medium', 'no', 'excellent', 'no']]
    # lables = ['age', 'income', 'student', 'credit_rating']

    # Watermelon data set 2.0 (Zhou Zhihua, "Machine Learning").  The
    # last element of each row is the class: 1 = good melon, 0 = bad.
    # NOTE(review): the class labels were lost when this page was
    # scraped; they were restored by matching each row against the
    # published data set -- please verify against the original table.
    data = [['dark_green', 'curl_up', 'little_heavily', 'distinct', 'sinking', 'hard_smooth', 1],
            ['black', 'curl_up', 'heavily', 'distinct', 'sinking', 'hard_smooth', 1],
            ['black', 'curl_up', 'little_heavily', 'distinct', 'sinking', 'hard_smooth', 1],
            ['dark_green', 'little_curl_up', 'little_heavily', 'distinct', 'little_sinking', 'soft_stick', 1],
            ['black', 'little_curl_up', 'little_heavily', 'little_blur', 'little_sinking', 'soft_stick', 1],
            ['dark_green', 'stiff', 'clear', 'distinct', 'even', 'soft_stick', 0],
            ['light_white', 'little_curl_up', 'heavily', 'little_blur', 'sinking', 'hard_smooth', 0],
            ['black', 'little_curl_up', 'little_heavily', 'distinct', 'little_sinking', 'soft_stick', 0],
            ['light_white', 'curl_up', 'little_heavily', 'blur', 'even', 'hard_smooth', 0],
            ['dark_green', 'curl_up', 'heavily', 'little_blur', 'little_sinking', 'hard_smooth', 0],
            ['dark_green', 'curl_up', 'heavily', 'distinct', 'sinking', 'hard_smooth', 1],
            ['light_white', 'curl_up', 'little_heavily', 'distinct', 'sinking', 'hard_smooth', 1],
            ['black', 'little_curl_up', 'little_heavily', 'distinct', 'little_sinking', 'hard_smooth', 1],
            ['black', 'little_curl_up', 'heavily', 'little_blur', 'little_sinking', 'hard_smooth', 0],
            ['light_white', 'stiff', 'clear', 'blur', 'even', 'hard_smooth', 0],
            ['light_white', 'curl_up', 'little_heavily', 'blur', 'even', 'soft_stick', 0],
            ['dark_green', 'little_curl_up', 'little_heavily', 'little_blur', 'sinking', 'hard_smooth', 0]]
    lables = ['color', 'root', 'knocks', 'texture', 'navel', 'touch']

    '''
    Pass copies of data and lables: the functions mutate the lists they
    receive, and Python passes lists by reference.
    '''
    decisiontree = createTree(data[:], lables[:])
    # NOTE(review): createPlot is defined further down in this file, so
    # running the file top-to-bottom raises NameError here; move this
    # __main__ block below the plotting code (or into a main() called at
    # the very end) before running.
    createPlot(decisiontree)
    decisiontree = cutBranch(decisiontree, data, lables[:])
    # Resubstitution accuracy of the pruned tree on the training data.
    k = 0
    for item in data:
        if(classify(decisiontree, lables[:], item) == item[-1]):
            k += 1
    print(float(k) / len(data))
    createPlot(decisiontree)
# ------------------------------------------------------------------
# Plotting the decision tree with matplotlib
# ------------------------------------------------------------------

import matplotlib.pyplot as plt

'''
Draw the decision tree with matplotlib's annotate.
Reference: http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.annotate
matplotlib.pyplot.annotate(text, xy, xytext, xycoords, textcoords, arrowprops)
text: the annotation text
xy: the point the arrow points at
xytext: position of the text (defaults to xy)
xycoords: coordinate system for xy
textcoords: coordinate system for xytext
arrowprops: arrow style
'''
decisionNode = dict(boxstyle="round4", color='#3366FF')  # style of internal (decision) nodes
leafNode = dict(boxstyle="circle", color='#FF6633')  # style of leaf nodes
arrow_args = dict(arrowstyle="<-", color='g')  # style of the connecting arrows


# Draw one annotated node with an arrow from its parent.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw nodeTxt in a nodeType-styled box at centerPt, with an arrow
    coming from parentPt (both points in axes-fraction coordinates)."""
    annotate_kwargs = dict(
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
    createPlot.ax1.annotate(nodeTxt, **annotate_kwargs)


# Count the leaves of the tree.
def getNumLeafs(myTree):
    '''Return the number of leaf nodes in a nested-dict decision tree.'''
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # A dict child is a subtree; anything else is a leaf.
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Compute the number of decision levels in the tree.
def getTreeDepth(myTree):
    '''Return the depth of the tree: 1 for a single decision node, plus
    1 for every further level of nested dicts.'''
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# Write the edge label between a parent and a child node.
def plotMidText(cntrPt, parentPt, txtString):
    '''Place txtString at the midpoint of the parent->child edge.'''
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    # rotation value was lost in the scrape; 30 degrees restored from the
    # standard treePlotter implementation this code follows -- confirm.
    createPlot.ax1.text(xMid, yMid, txtString, va="center",
                        ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    '''Recursively plot subtree myTree whose parent node sits at parentPt;
    nodeTxt labels the incoming edge.

    Uses function attributes set by createPlot: totalW/totalD (leaf count
    and depth of the whole tree) and the xOff/yOff plotting cursor.
    '''
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node above the span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) /
              2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw this decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff,
                                       plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # ascend back


def createPlot(inTree):
    '''Open a figure, draw the whole decision tree inTree, and show it.'''
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axes ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Global layout parameters consumed by plotTree.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
           

# (The rendered tree figures from the original post are omitted here.)