導讀
前期在做一些機器學習的預研工作,對一篇遷移随機森林的論文進行了算法複現,其中需要對sklearn中的決策樹進行繼承和擴充API,這就要求了解決策樹的底層是如何設計和實作的。本文圍繞這一細節加以簡單介紹和分享。

決策樹是一種經典的機器學習算法,先後經曆了ID3、C4.5和CART等幾個主要版本疊代,sklearn中内置的決策樹實作主要是對标CART樹,但有部分原理細節上的差異,關于決策樹的算法原理,可參考曆史文章:暢快!5000字通俗講透決策樹基本原理。決策樹既可用于分類也可實作回歸,同時更是構成了衆多內建算法的根基,是以在機器學習領域有着舉重輕重的作用,關于內建算法,可參考曆史文章:一張圖介紹機器學習中的內建學習算法。
為了探究sklearn中決策樹是如何設計和實作的,以分類決策樹為例,首先看下決策樹都内置了哪些屬性和接口:通過dir屬性檢視一顆初始的決策樹都包含了哪些屬性(這裡過濾掉了以"_"開頭的屬性,因為一般是内置私有屬性),得到結果如下:
上述這些接口中,主要分為兩類:屬性和函數(這貌似說了句廢話:了解程式設計語言中類的定義都知道,類主要是包括屬性和函數的,其中屬性對應取值,函數對應功能實作)。如果需要具體區分哪些是屬性,哪些是函數,可以通過ipython解釋器中的自動補全功能。
大緻浏覽上述結果,屬性主要是決策樹初始化時的參數,例如ccp_alpha:剪枝系數,class_weight:類的權重,criterion:分裂準則等;還有就是決策樹實作的主要函數,例如fit:模型訓練,predict:模型預測等等。
本文的重點是探究決策樹中是如何儲存訓練後的"那顆樹",是以我們進一步用鸢尾花資料集對決策樹進行訓練一下,而後再次調用dir函數,看看增加了哪些屬性和接口:
通過集合的差集,很明顯看出訓練前後的決策樹主要是增加了6個屬性(都是屬性,而非函數功能),其中通過屬性名字也很容易推斷其含義:
- classes_:分類标簽的取值,即y的唯一值集合
- max_features_:最大特征數
- n_classes_:類别數,如2分類或多分類等,即classes_屬性中的長度
- n_features_in_:輸入特征數量,等價于老版sklearn中的n_features_,現已棄用,并推薦n_features_in_
- n_outputs:多輸出的個數,即決策樹不僅可以用于實作單一的分類問題,還可同時實作多個分類問題,例如給定一組人物特征,用于同時判斷其是男/女、胖/瘦和高矮,這是3個分類問題,即3輸出(需要差別了解多分類和多輸出任務)
- tree_:毫無疑問,這個tree_就是今天本文的重點,是在決策樹訓練之後新增的屬性集,其中存儲了決策樹是如何存儲的。
那我們對這個tree_屬性做進一步探究,首先列印該tree_屬性發現,這是一個Tree對象,并給出了在sklearn中的檔案路徑:
我們可以通過help方法檢視Tree類的介紹:
通過上述doc文檔,其中第一句就很明确的對決策樹做了如下描述:
Array-based representation of a binary decision tree.
即:基于數組表示的二分類決策樹,也就是二叉樹!進一步地,在這個二叉樹中,數組的第i個元素代表了決策樹的第i個節點的資訊,節點0表示決策樹的根節點。那麼每個節點又都蘊含了什麼資訊呢?我們注意到上述文檔中列出了節點的檔案名:_tree.pxd,檢視其中,很容易發現節點的定義如下:
雖然是cython的定義文法,但也不難推斷其各屬性字段的類型和含義,例如:
- left_child:size類型(無符号整型),代表了目前節點的左子節點的索引
- right_child:類似于left_child
- feature:size類型,代表了目前節點用于分裂的特征索引,即在訓練集中用第幾列特征進行分裂
- threshold:double類型,代表了目前節點選用相應特征時的分裂門檻值,一般是≤該門檻值時進入左子節點,否則進入右子節點
- n_node_samples:size類型,代表了訓練時落入到該節點的樣本總數。顯然,父節點的n_node_samples将等于其左右子節點的n_node_samples之和。
至此,決策樹中單個節點的屬性定義和實作基本推斷完畢,那麼整個決策樹又是如何将所有節點串起來的呢?我們再次訴諸于訓練後決策樹的tree_屬性,看看它都哪些接口,仍然過濾掉内置私有屬性,得到如下結果:
當然,也可通過ipython解釋器的自動補全功能,進一步檢視各接口是屬性還是函數:
其中很多屬性在前述解釋節點定義時已有提及,這裡需重點關注如下幾個屬性值:
- node_count:該決策樹中節點總數
- children_left:每個節點的左子節點數組
- children_right:每個節點的右子節點數組
- feature:每個節點選用分裂的特征索引數組
- threshold:每個節點選用分裂的特征門檻值數組
- value:落入每個節點的各類樣本數量統計
- n_leaves:葉子節點總數
大概比較重要的就是這些了!為了進一步了解各屬性中的資料是如何存儲的,我們仍以鸢尾花資料集為例,訓練一個max_depth=2的決策樹(根節點對應depth=0),并檢視如下取值:
可知:
- 訓練後的決策樹共包含5個節點,其中3個葉子節點
- 通過children_left和children_right兩個屬性,可以知道第0個節點(也就是根節點)的左子節點索引為1,右子節點索引為2,;第1個節點的左右子節點均為-1,意味着該節點即為葉子節點;第2個節點的左右子節點分别為3和4,說明它是一個内部節點,并做了進一步分裂
- 通過feature和threshold兩個屬性,可以知道第0個節點(根節點)使用索引為3的特征(對應第4列特征)進行分裂,且其最優分割門檻值為0.8;第1個節點因為是葉子節點,是以不再分裂,其對應feature和threshold字段均為-2
- 通過value屬性,可以檢視落入每個節點的各類樣本數量,由于鸢尾花資料集是一個三分類問題,且該決策樹共有5個節點,是以value的取值為一個5×3的二維數組,例如第一行代表落入根節點的樣本計數為[50, 50, 50],第二行代表落入左子節點的樣本計數為[50, 0, 0],由于已經是純的了,是以不再繼續分裂。
- 另外,tree中實際上并未直接标出各葉節點所對應的标簽值,但完全可通過value屬性來得到,即各葉子節點中落入樣本最多的類别即為相應标簽。甚至說,不僅可知道對應标簽,還可通過計算數量之比得到相應的機率!
拿鸢尾花資料集手動驗證一下上述猜想,以根節點的分裂特征3和門檻值0.8進行分裂,得到落入左子節點的樣本計數結果如下,發現确實是分裂後隻剩下50個第一類樣本,也即樣本計數為[50, 0, 0],完全一緻。
另外,通過children_left和children_right兩個屬性的子節點對應關系,其實我們還可以推斷出該二叉樹的周遊方式為前序周遊,即按照根-左-右的順序,對于上述決策樹其分裂後對應二叉樹示意圖如下: