以决策树作为开始,因为简单,而且也比较容易用到,当前的boosting或random forest也是常以其为基础的
决策树算法本身参考之前的blog,其实就是贪婪算法,每次切分使得数据变得最为有序
那么如何来定义有序或无序?
无序,node impurity

对于分类问题,我们可以用熵entropy或gini来表示信息的无序程度
对于回归问题,我们用方差variance来表示无序程度,方差越大,说明数据间差异越大
information gain
用于表示,由父节点划分后得到子节点,所带来的impurity的下降,即有序性的增益
mlib决策树的例子
下面直接看个regression的例子,分类的case,差不多,
还是比较简单的,
由于是回归,所以impurity的定义为variance
maxdepth,最大树深,设为5
maxbins,最大的划分数
先理解什么是bin,决策树的算法就是对feature的取值不断的进行划分
但如果是有序的,即按老,中,少的序,那么只有m-1个,即2种划分,老|中,少;老,中|少
对于连续的feature,其实就是进行范围划分,而划分的点就是split,划分出的区间就是bin
对于连续feature,理论上划分点是无数的,但是出于效率我们总要选取合适的划分点
有个比较常用的方法是取出训练集中该feature出现过的值作为划分点,
但对于分布式数据,取出所有的值进行排序也比较费资源,所以可以采取sample的方式
源码分析
首先调用,decisiontree.trainregressor,类似调用静态函数(object decisiontree)
org.apache.spark.mllib.tree.decisiontree.scala
调用静态函数train
可以看到将所有参数封装到strategy类,然后初始化decisiontree类对象,继续调用成员函数train
可以看到,这里decisiontree的设计是基于randomforest的特例,即单颗树的randomforest
所以调用randomforest.train(),最终因为只有一棵树,所以取trees(0)
org.apache.spark.mllib.tree.randomforest.scala
重点看下,randomforest里面的train做了什么?
1. decisiontreemetadata.buildmetadata
org.apache.spark.mllib.tree.impl.decisiontreemetadata.scala
这里生成一些后面需要用到的metadata
最关键的是计算每个feature的bins和splits的数目,
计算bins的数目
其他case,bins数目等于feature的numcategories
对于unordered情况,比较特殊,
根据bins数目,计算splits
2. decisiontree.findsplitsbins
首先找出每个feature上可能出现的splits和相应的bins,这是后续算法的基础
这里的注释解释了上面如何计算splits和bins数目的算法
a,对于连续数据,对于一个feature,splits = numbins - 1;上面也说了对于连续值,其实splits可以无限的,如何找到numbins - 1个splits,很简单,这里用sample
b,对于离散数据,两个case
b.2, 有序的feature,用于regression,二元分类,或high-arity的多元分类,这种case下划分的可能比较少,m-1,所以用每个category作为划分
3. treepoint和baggedpoint
treepoint是labeledpoint的内部数据结构,这里需要做转换,
arr是findbin的结果,
这里主要是针对连续特征做处理,将连续的值通过二分查找转换为相应bin的index
对于离散数据,bin等同于featurevalue.toint
baggedpoint,由于random forest是比较典型的bagging算法,所以需要对训练集做bootstrap sample
而对于decision tree是特殊的单根random forest,所以不需要做抽样
baggedpoint.converttobaggedrddwithoutsampling(treeinput)
其实只是做简单的封装
4. decisiontree.findbestsplits
这段代码写的有点复杂,尤其和randomforest混杂一起
总之,关键在
看看binstobestsplit的实现,为了清晰一点,我们只看continuous feature
四个参数,
binaggregates: dtstatsaggregator, 就是impurityaggregator,给出如果算出impurity的逻辑
splits: array[array[split]], feature对应的splits
featuresfornode: option[array[int]], tree node对应的feature
node: node, 哪个tree node
返回值,
(split, informationgainstats, predict),
split,最优的split对象(包含featureindex和splitindex)
informationgainstats,该split产生的gain对象,表明产生多少增益,多大程度降低impurity
predict,该节点的预测值,对于连续feature就是平均值,看后面的分析
predict,这个需要分析一下
predictwithimpurity.get._1,predictwithimpurity元组的第一个元素
calculatepredictimpurity的返回值中的predict
这里predict和impurity有什么不同,可以看出
impurity = impuritycalculator.calculate()
predict = impuritycalculator.predict
对于连续feature,我们就看variance的实现,
从calculate的实现可以看到,impurity求的就是方差, 不是标准差(均方差)