
之前我們學習的機器學習算法都是屬于分類算法,也就是預測值是離散值。當預測值為連續值時,就需要使用回歸算法。本文将介紹線性回歸的原理和代碼實作。
線性回歸原理與推導
如圖所示,這時一組二維的資料,我們先想想如何通過一條直線較好的拟合這些散點了?直白的說:盡量讓拟合的直線穿過這些散點(這些點離拟合直線很近)。
目标函數
要使這些點離拟合直線很近,我們需要用數學公式來表示。首先,我們要求的直線公式為:Y = XTw。我們這裡要求的就是這個w向量(類似于logistic回歸)。誤差最小,也就是預測值y和真實值的y的內插補點小,我們這裡采用平方誤差:
求解
我們所需要做的就是讓這個平方誤差最小即可,那就對w求導,最後w的計算公式為:
我們稱這個方法為OLS,也就是“普通最小二乘法”
線性回歸實踐
資料情況
我們首先讀入資料并用matplotlib庫來顯示這些資料。
def loadDataSet(filename):
numFeat = len(open(filename).readline().split('\t')) - 1
dataMat = [];labelMat = []
fr = open(filename)
for line in fr.readlines():
lineArr = []
curLine = line.strip().split('\t')
for i in range(numFeat):
lineArr.append(float(curLine[i]))
dataMat.append(lineArr)
labelMat.append(float(curLine[-1]))
return dataMat, labelMat
回歸算法
這裡直接求w就行,然後對直線進行可視化。
def standRegres(Xarr,yarr):
X = mat(Xarr);y = mat(yarr).T
XTX = X.T * X
if linalg.det(XTX) == 0:
print('不能求逆')
return
w = XTX.I * (X.T*y)
return w
算法優缺點
- 優點:易于了解和計算
- 缺點:精度不高