天天看點

漫談機器學習經典算法—了解EM算法

精美排版版可移步http://lanbing510.info/2015/11/12/Master-EM-Algorithm.html

寫在前面

EM(Expectation Maximization 期望最大化)算法是一種疊代算法,用于含有隐變量的機率模型參數的極大似然估計,或極大後驗機率估計。其每次疊代由E、M兩步構成。下面首先給出一般EM算法的求解過程(怎麼做),然後結合一個例子來了解,然後講為什麼這麼求解,即推導,最後講述EM算法在高斯混合模型中的應用及小結。

EM算法

一般用Y表示觀測随機變量,Z表示隐随機變量,Y和Z在一起稱為完全資料,Y稱為不完全資料。由于含有隐變量,我們不能直接由 P(Y|θ)=∑zP(Z|θ)P(Y|Z,θ) 的最大似然估計來得到模型參數 θ 。EM算法就是在給定Y和Z的聯合機率分布為 P(Y,Z|θ) 的情況下,通過疊代求解 L(θ)=lnP(Y|θ) 的極大似然估計來估算模型參數的算法。求解步驟如下:

1 選擇一個初始的參數 θold 。

2 E Step 估計 P(Z|Y,θold) 。

3 M Step 估計 θnew

θnew=argmaxθQ(θ,θold)

其中

Q(θ,θold)=∑zP(Z|Y,θold)logP(Y,Z|θ)

4 檢查是否到達停止疊代條件,一般是對較小的正數 ε1,ε2 ,若滿足:

∥θnew−θold∥<ε1或∥Q(θnew−θold)−Q(θold−θold)∥<ε2

則停止疊代,否則 θold←θnew 轉到步驟2繼續疊代。

一個栗子

下面結合一個《統計學習方法》中的例子來加深下了解:

例:假設有3枚硬币,分别記做A,B,C。這些硬币正面出現的機率分别是 π , p 和q。進行如下擲硬币實驗:先擲硬币A,根據其結果選出硬币B或C,正面選B,反面選硬币C;然後投擲選重中的硬币,出現正面記作1,反面記作0;獨立地重複n次(n=10),結果為:

1111110000

假設隻能觀察到投擲硬币的結果,而不知其過程,問如何估計三硬币正面出現的機率,即三硬币的模型參數 π , p 和q。

解答:我們現在隻可以看到硬币最終的結果1111110000,即觀測變量Y,至于結果來自于B還是C無從得知,我們設隐變量Z來表示來自于哪個變量,令 θ={π,p,q} 。觀測資料的似然函數可以表示為:

P(Y|θ)=∑zP(Z|θ)P(Y|Z,θ)=∏j=1n[πpyj(1−p)1−yj+(1−π)qyi(1−q)1−yj]

則模型參數 θ 的最大log似然估計為:

θ^=argmaxθlogP(Y|θ)

由于隐變量的存在,使得觀測資料的最大似然函數裡log裡有帶有加和( π..+(1−π).. ),導緻上式是沒有解析解的,隻能通過疊代的方法求解,EM算法就是用于求解此類問題的一種疊代算法。

根據第二部分EM算法求解過程,假設第 i 次疊代參數的估計值為θ(i)=(π(i),p(i),q(i)),EM算法的第i+1次疊代如下:

E步:估計 P(Z|Y,θ(i)) ,即在模型參數 θ(i) 下觀測資料 yj 來自硬币B的機率

μ(i+1)=P(Z|Y,θ(i))=π(i)(p(i))yj(1−p(i))1−yjπ(i)(p(i))yj(1−p(i))1−yj+(1−π(i))(q(i))yj(1−q(i))1−yj

M步:估計 θ(i+1) ,即:

θ(i+1)=argmaxθQ(θ,θi)

其中:

Q(θ,θ(i))=Ez[logP(Y,Z|θ)|Y,θ(i)]=∑zP(Z|Y,θ(i))logP(Y,Z|θ)=∑j=1n[μ(i+1)logπpyj(1−p)1−yj+(1−μ(i+1))log(1−π)qyi(1−q)1−yj]

μ(i+1) 是E步得到常數,可通過分别對參數 π , p 和q求偏導使其為零來最大化上式,獲得 π , p 和q的新的估計值:

π(i+1)=1n∑j=1nμ(i+1)j

p(i+1)=∑nj=1μ(i+1)jyj∑nj=1μ(i+1)j

q(i+1)=∑nj=1(1−μ(i+1)j)yj∑nj=1(1−μ(i+1)j)

在標明參數初始值 θ(0) 後,根據E步,M步循環疊代,直至滿足疊代停止條件,即可得到參數 θ 的極大似然估計。

由上述EM計算過程和結合例子的應用,相信大家都會用EM算法解決問題了,即知道怎麼做了,下面一節主要來講述為什麼這樣做,即為什麼這樣就可以解決此類含有隐變量的最大似然估計。

EM算法的導出

面對一個含有隐變量的機率模型,目标是極大化觀測資料(不完全資料),即極大化

L(θ)=logP(Y|θ)=log∑zP(Y,Z|θ)=log[∑zP(Z|θ)P(Y|Z,θ)]

極大化的主要困難是上式中含有未觀測資料log裡有包含和,EM算法則是通過疊代逐漸近似極大化 L(θ) 的。假設在第i次疊代後 θ 的估計值是 θ(i) ,我們希望新估計的值 θ 能使 L(θ) 增加,即 L(θ)>L(θ(i)) ,并逐漸達到極大值,考慮兩者的差:

L(θ)−L(θ(i))=log[∑zP(Z|θ)P(Y|Z,θ)]−logP(Y|θ(i))=log[∑zP(Z|Y,θ(i))P(Z|θ)P(Y|Z,θ)P(Z|Y,θ(i))]−logP(Y|θ(i))≥∑zP(Z|Y,θ(i))logP(Z|θ)P(Y|Z,θ)P(Z|Y,θ(i))−logP(Y|θ(i))=∑zP(Z|Y,θ(i))logP(Z|θ)P(Y|Z,θ)P(Y|θ(i))P(Z|Y,θ(i))

上式中的不等号是由Jensen不等式得到,下面對Jesen不等式做一個簡單回顧。

Jensen不等式:如果f是凸函數,X是随機變量,則 E[f(X)]≥f(EX) ,特别地,如果f是嚴格凸函數,那麼當 E[f(X)]=f(EX) 當且僅當 P(x=E[X])=1 ,即X為常量。凹函數的性質和凸函數相反。

上式中 f(X) 為 log(X) 為凹函數,則 E[f(X)]≤f(EX) ,是上式中不等号的由來。

繼續轉回正題。令

B(θ,θ(i))≜L(θ(i))+∑zP(Z|Y,θ(i))logP(Z|θ)P(Y|Z,θ)P(Z|θ(i))P(Z|Y,θ(i))

L(θ)≥B(θ,θ(i))

即 B(θ,θ(i)) 是 L(θ) 的下界,且 L(θ(i))=B(θ(i),θ(i))

是以,可以使 B(θ,θ(i)) 增大的 θ 也可以使 L(θ) 增大,為了使 L(θ) 有盡可能大的增長,選擇 θ(i+1) 使得 B(θ,θ(i)) 達到極大,即

θ(i+1)=argmaxθB(θ,θ(i))

省去對 θ 極大化而言是常熟的項,有

θ(i+1)=argmaxθ[L(θ(i))+∑zP(Z|Y,θ(i))logP(Z|θ)P(Y|Z,θ)P(Z|θ(i))P(Z|Y,θ(i))]=argmaxθ[∑zP(Z|Y,θ(i))logP(Z|θ)P(Y|Z,θ)]=argmaxθ[∑zP(Z|Y,θ(i))logP(Y,Z|θ)]=argmaxθQ(θ,θ(i))

推導完畢。

EM算法在高斯混合模型中的應用

EM算法可以用來估計高斯混合模型中的參數,例如:

假設觀測資料 y1,y2,...,yN 由高斯混合模型生成

P(y|θ)=∑k=1Kαkϕ(y|θk)

其中 θ=(α1,α2,...,αk;θ1,θ2,...,θk) 。

可以設想觀測資料 yj,j=1,2,...,N 是這樣産生,先根據機率 αk 選擇第 k 個高斯分布ϕ(y|θk),然後根據第 k 個模型的機率分布ϕ(y|θk)生成觀測資料 yj 。此時觀測資料 yj 是已知的,反映觀測資料 yj 來自哪個分模型是未知的,看作隐變量。可以看出,可以用EM算法估計高斯混合模型(含有隐變量的機率模型參數)的參數 θ 。

1 首先确定E步,估計 P(Z|Y,θ(i)) ,即在已知第 i 次疊代參數的情況下,觀測資料yj來自計算分模型 k 的機率。

γ(i+1)jk=P(Z|Y,θ(i))=α(i)kϕ(y|θ(i)k)∑Kk=1α(i)kϕ(y|θ(i)k),j=1,2,...,N;k=1,2,...,K

2 M步,将新估計的Z的機率代進最大似然的公式,對待估計參數分别求偏導,以計算新一輪疊代模型參數(詳細的推導不再贅述,感興趣的可以自行推導)。

μ(i+1)k=∑Nj=1γ(i+1)jkyj∑Nj=1γ(i+1)jk,k=1,2,...,K

(σ2k)(i+1)=∑Nj=1γ(i+1)jk(yj−μ(i+1)k)2∑Nj=1γ(i+1)jk,k=1,2,...,K

α(i+1)k=∑Nj=1γ(i+1)jkN,k=1,2,...,K

重複E步,M步直至收斂。

小結

1 EM算法是含有隐變量的機率模型極大似然估計或極大後驗機率估計的疊代算法。含有隐變量的機率資料表示為 P(Y,Z|θ) ,Y表示觀測變量,Z是隐變量, θ 是模型參數。EM算法通過疊代求解觀測資料的對數似然函數 L(θ)=log(P|θ) 的極大化來實作極大似然估計。每次疊代包含兩步:

E步,求解 P(Z|Y,θold) ;

M步,估計 θnew

θnew=argmaxθQ(θ,θold)

2 EM算法應用極其廣泛,主要用于含有隐變量的機率模型的學習,但其對參數初始值比較敏感,而且不能保證收斂到全局最優。

參考文獻

[1] 統計學習方法.李航

[2] Pattern Recognition And Machine Learning.Christopher M. Bishop

[3] The EM Algorithm,JerryLead’s Blog

[4] 三硬币問題

繼續閱讀