天天看點

決策樹ID3和C4.5算法Python實作源碼

首先推薦李航的《統計學習方法》這本書,這個實作就是按照書上的算法來的。Python 用的是 3.3 版的,和 2.x 不相容,運作的時候需要注意。

'''

Created on 2012-12-18


@author: weisu.yxd
'''


class Node:
    """A single node of the decision tree.

    A leaf carries a class label in ``result``; an internal node carries
    the id of its split feature in ``attr`` and one child per observed
    feature value in ``childs``.
    """
    def __init__(self, parent = None, dataset = None):
        self.parent = parent    # parent Node, or None for the root
        self.dataset = dataset  # training instances that reached this node
        self.result = None      # class label (leaf nodes only)
        self.attr = None        # id of the feature this node splits on
        self.childs = {}        # feature value -> child subtree


def entropy(props):
    """Shannon entropy (in bits) of a discrete probability distribution.

    props: tuple or list of probabilities (should sum to 1).
    Returns the entropy as a float, or None if props is not a tuple/list.

    Fix: probabilities equal to 0 are skipped, using the standard
    convention 0*log(0) = 0 — the original raised ValueError on p == 0.
    """
    if not isinstance(props, (tuple, list)):
        return None

    from math import log2  # available since Python 3.3, which this code targets
    e = 0.0
    for p in props:
        if p > 0:  # 0*log2(0) is defined as 0; log2(0) itself would raise
            e -= p * log2(p)
    return e


def info_gain(D, A, T = -1, return_ratio = False):
    """Information gain g(D,A) of feature A over dataset D.

    g(D,A) = entropy(D) - entropy(D|A), where the class label of each
    tuple is at index T (default -1, the last element).

    If return_ratio is True, returns the C4.5 gain ratio
    g_R(D,A) = g(D,A) / H_A(D), where H_A(D) is the *split information*
    (the entropy of the partition of D induced by A).

    Fixes:
    - the gain ratio previously divided by entropy(D) (the class entropy)
      instead of H_A(D), which is not the C4.5 definition;
    - a 0/0 when the denominator is zero is now reported as 0.0
      (a zero split information implies a single value of A, hence zero gain).
    """
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is int:
        return None
    C = {}    # class label -> count
    DA = {}   # value of feature A -> count
    CDA = {}  # (class label, value of A) -> count
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1

    n = len(D)
    entropy_D = entropy(tuple(c / n for c in C.values()))  # H(D)

    # conditional class probabilities, grouped by the value of A
    PCDA = {}  # value of A -> [P(class | A=a), ...]
    for (label, a), count in CDA.items():
        PCDA.setdefault(a, []).append(count / DA[a])

    # H(D|A) = sum over a of P(A=a) * H(D | A=a)
    condition_entropy = 0.0
    for a, v in DA.items():
        condition_entropy += (v / n) * entropy(PCDA[a])

    gain = entropy_D - condition_entropy
    if return_ratio:
        # C4.5 split information: entropy of the partition sizes induced by A
        split_info = entropy([v / n for v in DA.values()])
        return gain / split_info if split_info > 0 else 0.0
    return gain
    
def get_result(D, T = -1):
    """Return the majority value of target feature T over dataset D.

    Ties are broken in favor of the value first encountered in D
    (same behavior as the original manual max loop).

    Fix: an empty D previously raised UnboundLocalError because
    `result` was never bound; it now returns None.
    """
    if not isinstance(D, (set, list)):
        return None
    if not type(T) is int:
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    if not count:  # empty dataset: no majority label exists
        return None
    # max() returns the first key reaching the maximal count, matching
    # the original strict-greater-than scan over insertion order
    return max(count, key=count.get)


def devide_set(D, A):
    """Partition dataset D into subsets keyed by the value of feature A.

    Returns a dict mapping each observed value of feature A to the list
    of tuples carrying that value, or None on invalid arguments.
    """
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not int:
        return None
    partitions = {}
    for record in D:
        key = record[A]
        if key in partitions:
            partitions[key].append(record)
        else:
            partitions[key] = [record]
    return partitions


def build_tree(D, A, threshold = 0.0001, T = -1, Tree = None, algo = "ID3"):
    """Build a decision tree from dataset D over feature-id set A.

    D         : list of tuples (training instances)
    A         : set of candidate feature indices
    threshold : stop splitting when the best gain falls below this
    T         : index of the target attribute in each tuple (-1 = last)
    Tree      : node to grow (None creates a fresh root)
    algo      : "ID3" (information gain) or "C4.5" (gain ratio)

    Fixes over the original:
    - `algo` is now forwarded to the recursive calls, so C4.5 is used for
      the whole tree instead of only at the root;
    - `T` is now forwarded to info_gain/get_result instead of silently
      falling back to -1;
    - each branch recurses with its own copy of the remaining feature
      set; previously all siblings shared (and mutated) one set, so a
      feature consumed in one branch vanished from every other branch;
    - the caller's D and A are no longer destructively mutated
      (`del D[:]` / `A.discard`), which clobbered the module-level data.
    """
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is set:
        return None

    if Tree is None:
        Tree = Node(None, D)

    # all instances share one class label -> leaf
    subset = devide_set(D, T)
    if len(subset) <= 1:
        for key in subset.keys():
            Tree.result = key
        return Tree

    # no features left -> majority-vote leaf
    if len(A) <= 0:
        Tree.result = get_result(D, T)
        return Tree

    use_gain_ratio = algo != "ID3"
    max_gain = 0.0
    attr_id = None
    for a in A:
        gain = info_gain(D, a, T, return_ratio = use_gain_ratio)
        if gain > max_gain:
            max_gain = gain
            attr_id = a  # feature with the largest gain (ratio) so far

    # best split is not informative enough -> majority-vote leaf
    if max_gain < threshold:
        Tree.result = get_result(D, T)
        return Tree

    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    Tree.dataset = None  # drop the node's copy; children hold the partitions
    remaining = A - {attr_id}  # fresh set: do not mutate the caller's A
    for key, child_D in subD.items():
        child = Node(Tree, child_D)
        Tree.childs[key] = child
        # each branch gets its own feature-set copy and the same algorithm
        build_tree(child_D, set(remaining), threshold, T, child, algo)
    return Tree


def print_brance(brance, target):
    """Print one root-to-leaf branch followed by its target label.

    `brance` alternates feature id and feature value; even positions are
    followed by '=', odd positions by '∧', then "target = <label>".
    """
    for idx, item in enumerate(brance):
        print(item, end = ('=' if idx % 2 == 0 else '∧'))
    print("target =", target)


def print_tree(Tree, stack = None):
    """Print every branch of the tree, one line per leaf.

    `stack` accumulates the (feature id, feature value) path from the
    root and is printed by print_brance at each leaf.

    Fix: the original used the mutable default argument `stack=[]`,
    which is shared across calls — a second call to print_tree would
    start with leftover path entries from the first.
    """
    if stack is None:
        stack = []
    if Tree is None:
        return
    if Tree.result is not None:  # leaf: emit the accumulated path
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, child in Tree.childs.items():
        stack.append(key)
        print_tree(child, stack)
        stack.pop()
    stack.pop()
    
def classify(Tree, instance):
    """Classify `instance` by walking the tree from the root.

    Follows the child matching the instance's value of each node's split
    feature until a leaf is reached; returns the leaf's class label, or
    None when the tree is None. Raises KeyError for an unseen feature
    value, exactly like the recursive original.
    """
    node = Tree
    while node is not None:
        if node.result is not None:
            return node.result
        node = node.childs[instance[node.attr]]
    return None
     
# Training data: one tuple per loan applicant, the last element is the
# class label ("是"/"否" = approved / not approved).
# Columns appear to be: age group, has job, owns house, credit rating, label
# — presumably Table 5.1 from Li Hang's book; TODO confirm against the source.
dataset = [
   ("青年", "否", "否", "一般", "否")
   ,("青年", "否", "否", "好", "否")
   ,("青年", "是", "否", "好", "是")
   ,("青年", "是", "是", "一般", "是")
   ,("青年", "否", "否", "一般", "否")
   ,("中年", "否", "否", "一般", "否")
   ,("中年", "否", "否", "好", "否")
   ,("中年", "是", "是", "好", "是")
   ,("中年", "否", "是", "非常好", "是")
   ,("中年", "否", "是", "非常好", "是")
   ,("老年", "否", "是", "非常好", "是")
   ,("老年", "否", "是", "好", "是")
   ,("老年", "是", "否", "好", "是")
   ,("老年", "是", "否", "非常好", "是")
   ,("老年", "否", "否", "一般", "否")
]


# Build a tree over every column except the last (the class label),
# print all of its branches, then classify one unseen instance.
feature_ids = set(range(len(dataset[0]) - 1))
T = build_tree(dataset, feature_ids)
print_tree(T)
print(classify(T, ("老年", "否", "否", "一般")))
           

運作結果如下:

2=否∧1=否∧target = 否

2=否∧1=是∧target = 是

2=是∧target = 是