首先推薦李航的《統計學習方法》這本書,這個實作就是按照書上的算法來的。Python 用的是最新的3.3版的,和2.x不相容,運作的時候需要注意。
'''
Created on 2012-12-18
@author: weisu.yxd
'''
class Node:
    """One node of a decision tree.

    A leaf carries its class label in ``result``; an internal node carries
    the index of its splitting attribute in ``attr`` and maps each value of
    that attribute to a child subtree via ``childs``.
    """
    def __init__(self, parent = None, dataset = None):
        self.dataset = dataset  # training instances that fall on this node
        self.result = None      # class label (set only on leaf nodes)
        self.attr = None        # index of the attribute this node splits on
        self.childs = {}        # attribute value -> child subtree
        self.parent = parent    # parent node (None for the root)
def entropy(props):
    """Shannon entropy (base 2) of a probability distribution.

    props: tuple or list of probabilities; any other type returns None
    (original contract, kept for backward compatibility).

    Zero probabilities are skipped using the convention 0 * log2(0) == 0,
    so the function no longer raises a math domain error when a class has
    probability 0.  Uses math.log2 instead of the hand-rolled log(x)/log(2).
    """
    if not isinstance(props, (tuple, list)):
        return None
    from math import log2
    e = 0.0
    for p in props:
        if p > 0:  # 0 * log2(0) is defined as 0; log2(0) itself would raise
            e -= p * log2(p)
    return e
def info_gain(D, A, T = -1, return_ratio = False):
    """Information gain g(D, A) of attribute A over dataset D.

    g(D, A) = entropy(D) - entropy(D | A)

    D: list/set of tuples; the element at index T (default: last) is the
       class label.
    A: integer index of the attribute to evaluate.
    return_ratio: when True, return the C4.5 gain ratio instead:
       g(D, A) / SplitInfo_A(D), where SplitInfo_A is the entropy of the
       distribution of A's *values*.  The original divided by entropy(D),
       which is not the C4.5 definition.
    Returns None on invalid argument types or an empty dataset.
    """
    if not isinstance(D, (set, list)):
        return None
    if not isinstance(A, int):
        return None
    n = len(D)
    if n == 0:  # original raised ZeroDivisionError here
        return None
    C = {}    # class label -> count
    DA = {}   # value of A -> count
    CDA = {}  # (class label, value of A) -> count
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1
    entropy_D = entropy([c / n for c in C.values()])
    # conditional class probabilities P(class | A == a), grouped by a
    PCDA = {}
    for (c, a), cnt in CDA.items():
        PCDA.setdefault(a, []).append(cnt / DA[a])
    condition_entropy = 0.0
    for a, cnt in DA.items():
        condition_entropy += (cnt / n) * entropy(PCDA[a])
    gain = entropy_D - condition_entropy
    if return_ratio:
        # C4.5: divide by the split information of A, not by entropy(D);
        # a zero split-info (single-valued attribute) yields ratio 0.
        split_info = entropy([cnt / n for cnt in DA.values()])
        return 0.0 if split_info == 0.0 else gain / split_info
    return gain
def get_result(D, T = -1):
    """Return the most frequent value of target attribute T in dataset D.

    D: list/set of tuples; T: integer index of the target attribute.
    Returns None when D is not a set/list, T is not an int, or D is empty
    (the original raised UnboundLocalError on an empty dataset).
    Ties are broken in favour of the value encountered first.
    """
    if not isinstance(D, (set, list)):
        return None
    if type(T) is not int:
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    result, max_count = None, 0
    for key, value in count.items():
        if value > max_count:  # strict > keeps the first-seen winner on ties
            result, max_count = key, value
    return result
def devide_set(D, A):
    """Partition dataset D into subsets keyed by the value of attribute A.

    Returns a dict mapping each observed value of attribute A to the list
    of tuples carrying that value, or None on invalid argument types.
    """
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not int:
        return None
    partitions = {}
    for row in D:
        key = row[A]
        if key in partitions:
            partitions[key].append(row)
        else:
            partitions[key] = [row]
    return partitions
def build_tree(D, A, threshold = 0.0001, T = -1, Tree = None, algo = "ID3"):
    """Build a decision tree from dataset D and attribute-index set A.

    D: list/set of tuples (training instances).
    A: set of candidate attribute indices (no longer mutated by the call).
    threshold: stop splitting when the best gain falls below this value.
    T: index of the target attribute inside each tuple (default: last).
    Tree: node to grow (created when None); used by the recursion.
    algo: "ID3" (information gain) or anything else for C4.5 (gain ratio).
    Returns the root Node, or None on invalid argument types.

    Fixes vs. the original: the recursion now propagates `algo` and `T`
    (subtrees used to fall back to ID3 and T=-1); each branch gets its own
    copy of the remaining attributes instead of sharing and mutating A
    across sibling branches; the destructive `del D[:]` (which crashed for
    set inputs and emptied the caller's list) is gone; `attr_id` can no
    longer be read while unbound when threshold <= 0.
    """
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not set:
        return None
    if Tree is None:
        Tree = Node(None, D)
    # Pure node: every instance shares the same target value -> leaf.
    subset = devide_set(D, T)
    if len(subset) <= 1:
        for key in subset.keys():
            Tree.result = key
        return Tree
    # No attributes left to split on -> majority-vote leaf.
    if len(A) <= 0:
        Tree.result = get_result(D, T)
        return Tree
    use_gain_ratio = algo != "ID3"
    max_gain, attr_id = 0.0, None
    for a in A:
        gain = info_gain(D, a, T, return_ratio = use_gain_ratio)
        if gain > max_gain:
            max_gain, attr_id = gain, a  # keep the best-scoring attribute
    # Too little gain (or no usable attribute) -> majority-vote leaf.
    if attr_id is None or max_gain < threshold:
        Tree.result = get_result(D, T)
        return Tree
    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    Tree.dataset = None  # drop intermediate data held by this node
    remaining = A - {attr_id}  # per-branch copy; siblings keep their own set
    for key, rows in subD.items():
        child = Node(Tree, rows)
        Tree.childs[key] = child
        build_tree(rows, remaining, threshold, T, child, algo)
    return Tree
def print_brance(brance, target):
    """Print one root-to-leaf branch as 'attr=value∧attr=value∧target = X'.

    brance alternates attribute ids and attribute values, so the separator
    alternates between '=' (after an id) and '∧' (after a value).
    """
    use_equals = True
    for item in brance:
        print(item, end = ('=' if use_equals else '∧'))
        use_equals = not use_equals
    print("target =", target)
def print_tree(Tree, stack = None):
    """Print every branch of the tree, one line per leaf, via print_brance.

    stack accumulates the attribute ids and values along the current path.
    The original declared a mutable default argument (stack = []), which is
    shared between successive top-level calls and would leak path entries
    across calls; a None sentinel replaces it.
    """
    if stack is None:
        stack = []
    if Tree is None:
        return
    if Tree.result is not None:
        # Leaf: the accumulated path plus the class label is one branch.
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, child in Tree.childs.items():
        stack.append(key)
        print_tree(child, stack)
        stack.pop()   # undo the value pushed for this child
    stack.pop()       # undo the attribute id pushed for this node
def classify(Tree, instance):
    """Walk the tree and return the class label for `instance`.

    Returns None when Tree is None or when the instance carries an
    attribute value that has no branch in the tree (the original raised
    KeyError on values unseen during training).
    """
    if Tree is None:
        return None
    if Tree.result is not None:
        return Tree.result
    # .get() returns None for unseen values; the recursion turns that
    # missing branch into a None classification instead of a KeyError.
    branch = Tree.childs.get(instance[Tree.attr])
    return classify(branch, instance)
# Training data from Li Hang's "Statistical Learning Methods" (loan-application
# example).  Each tuple is one applicant; the last element is the class label
# ("是"/"否" = yes/no loan approval).  The first four columns are presumably
# (age group, has job, owns house, credit rating) — TODO confirm against the
# book's Table 5.1.
dataset = [
    ("青年", "否", "否", "一般", "否")
    ,("青年", "否", "否", "好", "否")
    ,("青年", "是", "否", "好", "是")
    ,("青年", "是", "是", "一般", "是")
    ,("青年", "否", "否", "一般", "否")
    ,("中年", "否", "否", "一般", "否")
    ,("中年", "否", "否", "好", "否")
    ,("中年", "是", "是", "好", "是")
    ,("中年", "否", "是", "非常好", "是")
    ,("中年", "否", "是", "非常好", "是")
    ,("老年", "否", "是", "非常好", "是")
    ,("老年", "否", "是", "好", "是")
    ,("老年", "是", "否", "好", "是")
    ,("老年", "是", "否", "非常好", "是")
    ,("老年", "否", "否", "一般", "否")
    ]
# Build the tree over all non-target attribute indices (0..len-2), print its
# branches, then classify a previously unseen instance.
# NOTE(review): build_tree may mutate its arguments (it deletes from D and
# discards from the attribute set) — do not reuse `dataset` after this call.
T = build_tree(dataset, set(range(0, len(dataset[0]) - 1)))
print_tree(T)
print(classify(T, ("老年", "否", "否", "一般")))
運作結果如下:
2=否∧1=否∧target = 否
2=否∧1=是∧target = 是
2=是∧target = 是
否