天天看點

資料挖掘中決策樹ID3非遞歸算法

最近看了下ID3算法,雖然很經典,但是網上90%的實作方式都是用的遞歸,大家都知道遞歸的效率低下,特别是當資料集很大的時候,剛好最近在看python,是以無聊就把它改成非遞歸的了

資料集:

Sunny Hot High Weak No
Sunny Hot High Strong No
Overcast Hot High Weak Yes
Rain Mild High Weak Yes
Rain Cool Normal Weak Yes
Rain Cool Normal Strong No
Overcast Cool Normal Strong Yes
Sunny Mild High Weak No
Sunny Cool Normal Weak Yes
Rain Mild Normal Weak Yes
Sunny Mild Normal Strong Yes
Overcast Mild High Strong Yes
Overcast Hot Normal Weak Yes
Rain Mild High Strong No
           

代碼:

#encoding:utf-8
__author__ = 'wangjiewen'

import json
from math import log

class ID3(object):
    """Non-recursive (stack-based) ID3 decision-tree learner.

    Attributes:
        matrix: training rows; each row is a list of feature values with the
            class label in the last position.
        labels: feature names, indexed by column number.
        tree:   nested-dict decision tree built by createTree(), shaped
            {featureName: {attrValue: subtree-or-classLabel, ...}}.
    """

    def __init__(self):
        self.matrix = []
        self.labels = ['Outlook', 'Temperature', 'Humidity', 'Wind']
        self.tree = {}
        self.loadFile('trainset.txt')

    def loadFile(self, filename):
        """Load a space-separated training file into self.matrix.

        :param filename: path to the data file
        :return: None
        """
        # 'with' guarantees the file handle is closed (the original leaked it).
        with open(filename) as fr:
            for lineStr in fr:
                self.matrix.append(lineStr.strip().split(' '))

    def calcEntropy(self, dataSet):
        """Shannon entropy of the class column: -sum(p(x) * log2 p(x)).

        :param dataSet: rows whose last element is the class label
        :return: float entropy in bits
        """
        nums = len(dataSet)

        # Count occurrences of each class label.
        labelsCount = {}
        for featVec in dataSet:
            classLabel = featVec[-1]
            labelsCount[classLabel] = labelsCount.get(classLabel, 0) + 1

        entropy = 0.0
        for count in labelsCount.values():
            prob = float(count) / nums
            entropy -= prob * log(prob, 2)
        return entropy

    def splitDataSet(self, dataSet, col):
        """Group the rows of dataSet by the value found in column col.

        Columns are NOT removed from the rows (unlike the classic recursive
        ID3) so that column indices stay aligned with self.labels.

        :param dataSet: list of rows
        :param col: column index to group by
        :return: dict mapping each attribute value to the list of its rows
        """
        result = {}
        for featVec in dataSet:
            # Copy the row so groups never alias the caller's lists.
            result.setdefault(featVec[col], []).append(featVec[:])
        return result

    def selectMaxGainCol(self, dataSet):
        """Return the column index with the highest information gain.

        :param dataSet: rows with the class label in the last column
        :return: best column index, or -1 if no column has positive gain
        """
        numsOfCol = len(dataSet[0]) - 1  # last column is the class label
        numsOfData = len(dataSet)

        maxGain = 0.0
        maxFeatCol = -1

        # Gain(col) = Entropy(S) - weighted sum of subset entropies.
        entropyS = self.calcEntropy(dataSet)

        for col in range(numsOfCol):
            featDict = self.splitDataSet(dataSet, col)

            tmpGain = entropyS
            for featArr in featDict.values():
                weight = len(featArr) / float(numsOfData)
                tmpGain -= weight * self.calcEntropy(featArr)

            if tmpGain > maxGain:
                maxGain = tmpGain
                maxFeatCol = col
        return maxFeatCol

    def createTree(self):
        """Build the decision tree iteratively with explicit stacks.

        :return: the nested-dict tree (also stored in self.tree)
        """
        initCol = self.selectMaxGainCol(self.matrix)
        root = {self.labels[initCol]: {}}

        # Three parallel stacks, popped together: the tree node being filled,
        # the column it splits on, and the rows it must partition.
        nodeStack = [root[self.labels[initCol]]]
        colStack = [initCol]
        dataStack = [self.matrix]

        while dataStack:
            dataSet = dataStack.pop()
            col = colStack.pop()
            pCur = nodeStack.pop()

            label = self.labels[col]
            # Non-root nodes are pushed as {label: {}}; children belong under
            # that inner dict.  The root node itself IS the inner dict.
            target = pCur[label] if label in pCur else pCur

            # Partition the rows on the chosen attribute.
            splitDict = self.splitDataSet(dataSet, col)
            for key in splitDict:
                data = splitDict[key]

                # Pure subset (single class): store a leaf label.
                classSet = set(example[-1] for example in data)
                if len(classSet) == 1:
                    target[key] = classSet.pop()
                    continue

                # Mixed subset: create an internal node and defer it.
                # (>1 because the class column is never removed from rows.)
                if len(data[0]) > 1:
                    tmpMaxCol = self.selectMaxGainCol(data)
                    # Bug fix: the original attached deep internal nodes to
                    # pCur instead of the inner attribute dict, corrupting
                    # any tree deeper than two levels.
                    target[key] = {self.labels[tmpMaxCol]: {}}
                    dataStack.append(data)
                    colStack.append(tmpMaxCol)
                    nodeStack.append(target[key])

        print(json.dumps(root, indent=4))
        self.tree = root
        return root

    def classify(self, featLabels=None, testVec=None):
        """Predict the class label of testVec by walking self.tree.

        :param featLabels: feature names matching testVec's ordering
        :param testVec: attribute values of the instance to classify
        :return: predicted class label, or None if a value is unseen
        """
        # None-defaults instead of mutable [] defaults; behavior unchanged.
        featLabels = [] if featLabels is None else featLabels
        testVec = [] if testVec is None else testVec

        # next(iter(...)) works on both Python 2 and 3 (dict.keys()[0] does not).
        rootKey = next(iter(self.tree))
        keyStack = [rootKey]
        nodeStack = [self.tree[rootKey]]

        while nodeStack:
            curKey = keyStack.pop()
            curNode = nodeStack.pop()
            featIndex = featLabels.index(curKey)

            # keyOfAttr is an attribute value, e.g. 'Sunny' or 'Rain'.
            for keyOfAttr in curNode:
                if testVec[featIndex] == keyOfAttr:
                    child = curNode[keyOfAttr]
                    if isinstance(child, dict):
                        # Internal node: {nextFeatureName: {...}} -- descend.
                        nextKey = next(iter(child))
                        keyStack.append(nextKey)
                        nodeStack.append(child[nextKey])
                    else:
                        # Leaf: the class label itself.
                        return child
        # Attribute value never seen during training.
        return None

    def test(self):
        """Swap in a tiny toy dataset; for manual experimentation only.

        :return: None
        """
        dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        self.matrix = dataSet
        self.labels = labels



# Driver: build the tree from trainset.txt, then classify a sample vector.
model = ID3()
# model.test()
tree = model.createTree()
featLabels = ['Outlook', 'Temperature', 'Humidity', 'Wind']
testVec = "Rain Mild High Weak".split(" ")
testVec2 = "Overcast Mild High Weak".split(" ")

result = model.classify(featLabels, testVec)
# Parenthesized print so the script runs on both Python 2 and Python 3
# (the original `print result` statement is Python 2 only).
print(result)
           

最近看了下ID3算法,雖然很經典,但是網上90%的實作方式都是用的遞歸,大家都知道遞歸的效率低下,特别是當資料集很大的時候,是以無聊就把它改成非遞歸的了

上代碼: