天天看點

Spark MLlib - Decision Tree源碼分析

以決策樹作為開始,因為簡單,而且也比較容易用到,目前的boosting或random forest也是常以其為基礎的

決策樹算法本身參考之前的blog,其實就是貪婪算法,每次切分使得資料變得最為有序

那麼如何來定義有序或無序?

無序,node impurity 

Spark MLlib - Decision Tree源碼分析

對于分類問題,我們可以用熵entropy或gini來表示資訊的無序程度 

對于回歸問題,我們用方差variance來表示無序程度,方差越大,說明資料間差異越大

information gain

用于表示,由父節點劃分後得到子節點,所帶來的impurity的下降,即有序性的增益

Spark MLlib - Decision Tree源碼分析

mlib決策樹的例子

下面直接看個regression的例子,分類的case,差不多,

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

還是比較簡單的,

由于是回歸,是以impurity的定義為variance 

maxdepth,最大樹深,設為5 

maxbins,最大的劃分數 

先了解什麼是bin,決策樹的算法就是對feature的取值不斷的進行劃分 

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

但如果是有序的,即按老,中,少的序,那麼隻有m-1個,即2種劃分,老|中,少;老,中|少

對于連續的feature,其實就是進行範圍劃分,而劃分的點就是split,劃分出的區間就是bin 

對于連續feature,理論上劃分點是無數的,但是出于效率我們總要選取合适的劃分點 

有個比較常用的方法是取出訓練集中該feature出現過的值作為劃分點, 

但對于分布式資料,取出所有的值進行排序也比較費資源,是以可以采取sample的方式

源碼分析

首先調用,decisiontree.trainregressor,類似調用靜态函數(object decisiontree)

org.apache.spark.mllib.tree.decisiontree.scala

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

調用靜态函數train

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

可以看到将所有參數封裝到strategy類,然後初始化decisiontree類對象,繼續調用成員函數train

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

可以看到,這裡decisiontree的設計是基于randomforest的特例,即單顆樹的randomforest 

是以調用randomforest.train(),最終因為隻有一棵樹,是以取trees(0)

org.apache.spark.mllib.tree.randomforest.scala

重點看下,randomforest裡面的train做了什麼?

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

1. decisiontreemetadata.buildmetadata

org.apache.spark.mllib.tree.impl.decisiontreemetadata.scala

這裡生成一些後面需要用到的metadata 

最關鍵的是計算每個feature的bins和splits的數目,

計算bins的數目

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

其他case,bins數目等于feature的numcategories 

對于unordered情況,比較特殊,

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

根據bins數目,計算splits

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

2. decisiontree.findsplitsbins

首先找出每個feature上可能出現的splits和相應的bins,這是後續算法的基礎 

這裡的注釋解釋了上面如何計算splits和bins數目的算法

a,對于連續資料,對于一個feature,splits = numbins - 1;上面也說了對于連續值,其實splits可以無限的,如何找到numbins - 1個splits,很簡單,這裡用sample 

b,對于離散資料,兩個case 

Spark MLlib - Decision Tree源碼分析

    b.2, 有序的feature,用于regression,二進制分類,或high-arity的多元分類,這種case下劃分的可能比較少,m-1,是以用每個category作為劃分

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

3. treepoint和baggedpoint

treepoint是labeledpoint的内部資料結構,這裡需要做轉換,

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

arr是findbin的結果, 

這裡主要是針對連續特征做處理,将連續的值通過二分查找轉換為相應bin的index 

對于離散資料,bin等同于featurevalue.toint

baggedpoint,由于random forest是比較典型的bagging算法,是以需要對訓練集做bootstrap sample 

而對于decision tree是特殊的單根random forest,是以不需要做抽樣 

baggedpoint.converttobaggedrddwithoutsampling(treeinput) 

其實隻是做簡單的封裝

4. decisiontree.findbestsplits

這段代碼寫的有點複雜,尤其和randomforest混雜一起

總之,關鍵在

看看binstobestsplit的實作,為了清晰一點,我們隻看continuous feature

四個參數,

binaggregates: dtstatsaggregator, 就是impurityaggregator,給出如果算出impurity的邏輯 

splits: array[array[split]], feature對應的splits 

featuresfornode: option[array[int]], tree node對應的feature  

node: node, 哪個tree node

傳回值,

(split, informationgainstats, predict), 

split,最優的split對象(包含featureindex和splitindex) 

informationgainstats,該split産生的gain對象,表明産生多少增益,多大程度降低impurity 

predict,該節點的預測值,對于連續feature就是平均值,看後面的分析

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

predict,這個需要分析一下 

predictwithimpurity.get._1,predictwithimpurity元組的第一個元素 

calculatepredictimpurity的傳回值中的predict

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

這裡predict和impurity有什麼不同,可以看出 

impurity = impuritycalculator.calculate() 

predict = impuritycalculator.predict

對于連續feature,我們就看variance的實作,

Spark MLlib - Decision Tree源碼分析
Spark MLlib - Decision Tree源碼分析

從calculate的實作可以看到,impurity求的就是方差, 不是标準差(均方差)

Spark MLlib - Decision Tree源碼分析

繼續閱讀