天天看點

ML之DT:基于簡單回歸問題訓練決策樹(DIY資料集+七種{1~7}深度的決策樹{依次進行10交叉驗證})

輸出結果

ML之DT:基于簡單回歸問題訓練決策樹(DIY資料集+七種{1~7}深度的決策樹{依次進行10交叉驗證})

設計思路

ML之DT:基于簡單回歸問題訓練決策樹(DIY資料集+七種{1~7}深度的決策樹{依次進行10交叉驗證})

核心代碼

for iDepth in depthList:

   for ixval in range(nxval):

       idxTest = [a for a in range(nrow) if a%nxval == ixval%nxval]

       idxTrain = [a for a in range(nrow) if a%nxval != ixval%nxval]

       xTrain = [x[r] for r in idxTrain]

       xTest = [x[r] for r in idxTest]

       yTrain = [y[r] for r in idxTrain]

       yTest = [y[r] for r in idxTest]

       treeModel = DecisionTreeRegressor(max_depth=iDepth)

       treeModel.fit(xTrain, yTrain)

       treePrediction = treeModel.predict(xTest)

       error = [yTest[r] - treePrediction[r] for r in range(len(yTest))]

       if ixval == 0:

           oosErrors = sum([e * e for e in error])

       else:

           oosErrors += sum([e * e for e in error])

   mse = oosErrors/nrow

   xvalMSE.append(mse)

繼續閱讀