天天看點

GBDT-回歸樹的建立

GBDT中的樹是回歸樹。

最近才開始看GBDT的内容,發現基礎是回歸樹,抓起統計學習方法(p69 算法5.5)就開始看,發現書上的那些式子很晦澀,翻閱了很多的部落格,大緻了解了回歸樹的建立方法。之後會總結GBDT的回歸和分類模型。

一。算法思想,回歸樹遞歸地将每個區域劃分為兩個子區域并決定每個子區域的輸出值。

二。步驟:

1.周遊標明的變量j,對固定的切分變量j掃描切分點s,選擇使下式最小的最小切分變量j和切分點s。

GBDT-回歸樹的建立

2.用標明的(j,s)劃分出區域R1,R2,并決定它們的輸出值,也就是區域内的均值。

3.重複1,2步驟,直到滿足結束條件

4.生成回歸樹,也可以了解為是分段函數

三。代碼實作:

此代碼是看了此作者的部落格寫成。參考:https://blog.csdn.net/xiaoxiao_wen/article/details/54098015

import numpy as np

#參考 https://blog.csdn.net/xiaoxiao_wen/article/details/54098015

def cart_regression_tree(start,end,y):
    if start!=end:#終止條件
        m = []
        for i in range(start,end):
            c1 = np.average(y[start:i+1])
            c2 = np.average(y[i+1:end+1])#注意這裡需要end+1,要不然最後一個元素取不到
            y1 = y[start:i+1]
            y2 = y[i+1:end+1]
            m.append((sum(pow((y1-c1),2))+sum(pow((y2-c2),2))))

        index = m.index(min(m))+start#注意這裡需要加start,否則會死循環
        print("切分點為:",index)
        print("大于",index,"的輸出值為",np.average(y[start:index+1]))
        print("小于",index,"的輸出值為",np.average(y[index+1:end+1]))
        cart_regression_tree(start,index,y)
        cart_regression_tree(index+1,end,y)
    else:
        return None

if __name__ == '__main__':
    x = np.arange(0,10)
    y = [4.5,4.75,4.91,5.34,5.80,7.05,7.90,8.23,8.70,9.00]
    cart_regression_tree(0,9,y)
           

繼續閱讀