天天看点

线性回归学习心得

线性回归学习心得

本文是自己以周志华老师的西瓜书为主要学习媒介,以吴恩达老师的机器学习视频为补充的线性回归学习心得。线性回归是机器学习的入门,虽比较基础但极为重要。

个人觉得,西瓜书的3.2节写得已经十分精彩,我再赘述很难达到周老师的高度。下面也推荐一个博客链接,我觉得他对线性回归的整理也是非常精彩了,本文仅仅是它的一个补充:

https://blog.csdn.net/KevinBetterQ/article/details/83117342

本文的主要精力集中在西瓜书中不够详细的公式推导和满秩矩阵上。

1.式(3.10)的矩阵求导

线性回归学习心得

在西瓜书交流群里,这一步的推导是许多群友大惑不解的。这里的关键是矩阵求导,因为在我们学微积分的时候学了求导没学矩阵求导,讲矩阵的时候也没有讲矩阵求导,所以成为了知识盲点。但是矩阵求导的关键,是要展开。

参考wiki的矩阵求导公式,对实数的求导公式有:

线性回归学习心得
线性回归学习心得

我们先看被求导项,其中每一项矩阵维数为:

y : m × 1 y:m \times 1 y:m×1

x : m × ( d + 1 ) x:m \times (d + 1) x:m×(d+1)

w ^ : ( d + 1 ) × 1 \widehat w:(d + 1) \times 1 w

:(d+1)×1

所以Ew是(1xm维)*(mx1维),最后结果是一个1维实数。实数对一个(d+1)*1维的列向量进行求导,结果依然是一个(d+1)*1维的列向量。

再看Ew的详细展开式:

E w ^ = ( y − X w ^ ) T ( y − X w ^ ) = y T y − w ^ T X T y − y T X w ^ + ( X w ^ ) T X w ^ {{\rm{E}}_{\widehat w}} = {(y - X\widehat w)^T}(y - X\widehat w) = {y^T}y - {\widehat w^T}{X^T}y - {y^T}X\widehat w + {(X\widehat w)^T}X\widehat w Ew

​=(y−Xw

)T(y−Xw

)=yTy−w

TXTy−yTXw

+(Xw

)TXw

式子中第二、三项关系为转置,即:

w ^ T X T y = ( y T X w ^ ) T {\widehat w^T}{X^T}y = {({y^T}X\widehat w)^T} w

TXTy=(yTXw

)T

且根据前面的维度概述,可很容易知道Ew的每一项都是实数。实数的转置=实数本身,故有:

E w ^ = y T y − 2 y T X w ^ + ( X w ^ ) T X w ^ {{\rm{E}}_{\widehat w}} = {y^T}y - 2{y^T}X\widehat w + {(X\widehat w)^T}X\widehat w Ew

​=yTy−2yTXw

+(Xw

)TXw

根据矩阵求导的公式有:

∂ ( A w ) ∂ w = A T {{\partial (Aw)} \over {\partial w}} = {A^T} ∂w∂(Aw)​=AT

∂ ( ( A w ^ ) T A w ^ ) ∂ w ^ = 2 A T A w ^ {{\partial ({{(A\widehat w)}^T}A\widehat w)} \over {\partial \widehat w}} = 2{A^T}A\widehat w ∂w

∂((Aw

)TAw

)​=2ATAw

故得到最终(3.10)的结论:

E w ^ = 2 X T X w ^ − 2 X T y {{\rm{E}}_{\widehat w}} = 2{X^T}X\widehat w - 2{X^T}y Ew

​=2XTXw

−2XTy

2.满秩矩阵

线性回归的最优化系数w要想求解,有多种方法,包括一步到位的正规方程法、梯度下降法等。详细请见参考资料2的博客。本文补充一下满秩矩阵的说明。

矩阵的秩——假设m*n的矩阵的秩为k,它指的是能从矩阵中任意抽取k行k列,位于这些行列交叉处的元素按原来顺序构成k阶行列式(又称k阶子式),其值不为0.而任意阶数>k的子式行列式的值都为0。

正规方程法的求解方式为(3.11)式:

线性回归学习心得

正规方程法能生效的前提,要求矩阵满秩,即矩阵的秩为d+1,d为样本的特征个数。

西瓜书里并没有提到的是下面的细节:

如果是满秩矩阵,那就有唯一解,使用正规方程法能非常方便地求出最优解w,使得均方误差最小化。不过根据吴恩达等老师的视频资料,建议在n<10000时使用正规方程法能得到较小的时间复杂度。由于时间复杂度为o(n^3),n过大时时间复杂度过高,不如梯度下降等方法的性价比。

如果不是满秩矩阵,那就有无穷多个解。此时建议引入L1或者L2正则化项。

参考资料:

1.https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-vector_identities

2.https://blog.csdn.net/KevinBetterQ/article/details/83117342

继续阅读