天天看點

從兩個例子了解EM算法從兩個例子了解EM算法

從兩個例子了解EM算法

本文是作者對EM算法學習的筆記,從EM算法出發介紹EM算法,為了更好了解,用兩個應用EM算法求解的例子進一步解釋EM的應用。

EM算法

EM算法引入

EM算法,指的是最大期望算法(Expectation Maximization Algorithm,期望最大化算法),是一種疊代算法,在統計學中被用于尋找,依賴于不可觀察的隐性變量的機率模型中,參數的最大似然估計。基本思想是首先随機取一個值去初始化待估計的參數值,然後不斷疊代尋找更優的參數使得其似然函數比原來的似然函數大。

  • EM算法當做最大似然估計的拓展,解決難以給出解析解(模型中存在隐變量)的最大似然估計(MLE)問題
  • 在算法中加入隐變量的思想可以類比為幾何題中加入一條輔助線的做法。

假定有訓練集{ x(1),x(2),...x(m) x ( 1 ) , x ( 2 ) , . . . x ( m ) },包含 m m 個獨立樣本,希望從中找到該組資料的模型p(x,z)p(x,z)的參數。

對數似然函數表達如下:

從兩個例子了解EM算法從兩個例子了解EM算法

在表達式中因為存在隐變量,直接找到參數估計比較困難,是以我們通過EM算法疊代求解下界的最大值,直到收斂。

我們通過以下的圖檔來解釋這一過程:

從兩個例子了解EM算法從兩個例子了解EM算法

圖檔上的紫色部分是我們的目标模型 p(x|θ) p ( x | θ ) 曲線,該模型比較複雜,難以直接求解其解析解,為了消除隐變量 z z 帶來的影響,我們可以得到一個不包含的zz的模型 r(x|θ) r ( x | θ ) (該函數是我們自己標明的,是以最大值可求解), 同時滿足條件 r(x|θ)≤p(x|θ) r ( x | θ ) ≤ p ( x | θ ) 。

  • 我們先取一個 θ1 θ 1 ,使得 r(x|θ1)=p(x|θ1) r ( x | θ 1 ) = p ( x | θ 1 ) (如綠線所示),然後再對此時的 r r 求其最大值,得到極值點θ2θ2,實作參數的更新。
  • 不斷重複以上過程,在更新過程中始終滿足 r≤p r ≤ p 直到收斂。

從以上過程來看,EM算法的核心就是如何找到這個 r r ,即pp的下界函數。

這個下界函數有多種方法了解,我們從Jensen不等式的角度來了解。

從兩個例子了解EM算法從兩個例子了解EM算法

上述等号成立的條件是

p(x(i),z(i);θ)Qi(z(i))=c p ( x ( i ) , z ( i ) ; θ ) Q i ( z ( i ) ) = c , ∑zQi(z(i))=1 ∑ z Q i ( z ( i ) ) = 1 ,是以:

從兩個例子了解EM算法從兩個例子了解EM算法

最終架構如下:

從兩個例子了解EM算法從兩個例子了解EM算法

EM推導高斯混合模型

高斯混合模型GMM

設有随機變量 X X , 則高斯混合模型可以用p(x)=∑Kπk(x|μk,Σk)p(x)=∑KπkN(x|μk,Σk),其中 (x|μk,Σk) N ( x | μ k , Σ k ) 表示混合模型中的第 k k 個分量πkπk表示混合系數,滿足

∑kπk=1,0≤πk≤1 ∑ k π k = 1 , 0 ≤ π k ≤ 1 。

我們知道高斯函數的機率分布為 f(x)=1(√2π)σexp(−(x−μ)22σ2) f ( x ) = 1 ( 2 π ) σ e x p ( − ( x − μ ) 2 2 σ 2 ) , 在混合高斯分布中待估計變量就包括了 μ,σ,π μ , σ , π 。

對數似然函數為 lμ,Σ,π=∑Ni=1log(∑Kk=1)πk(xi|μk,Σk)) l μ , Σ , π = ∑ i = 1 N l o g ( ∑ k = 1 K ) π k N ( x i | μ k , Σ k ) )

EM 推導過程

第一步:估算資料來自于哪個組分,即估計每一個組分生成的機率,對每個樣本 xi x i ,它由第 k k 個組份生成的機率可以記作:γ(i,k)=πk(xi|μk,Σk)∑jπj(xi|μj,Σj)γ(i,k)=πkN(xi|μk,Σk)∑jπjN(xi|μj,Σj)

第二步:估計每個組份的參數

E-step: 在給定了樣本和每個高斯分布的參數以及組份的分布函數的情況下

w(i)j=Qi(z(i)=j)=p(z(i))=j|x(i);ϕ,μ,Σ) w j ( i ) = Q i ( z ( i ) = j ) = p ( z ( i ) ) = j | x ( i ) ; ϕ , μ , Σ )

M-step:将多項式分布和高斯分布的參數帶入:

∑mi=1∑z(i)Qi(z(i))logp(x(i),z(i);ϕ,μ,Σ)Qi(z(i)) ∑ i = 1 m ∑ z ( i ) Q i ( z ( i ) ) l o g p ( x ( i ) , z ( i ) ; ϕ , μ , Σ ) Q i ( z ( i ) )

=∑mi=1∑kj=1Qi(z(i)=j)logp(x(i)|z(i)=j;ϕ,μ,Σ)p(z(i)=j;ϕ)Qi(z(i)) = ∑ i = 1 m ∑ j = 1 k Q i ( z ( i ) = j ) l o g p ( x ( i ) | z ( i ) = j ; ϕ , μ , Σ ) p ( z ( i ) = j ; ϕ ) Q i ( z ( i ) )

∑mi=1∑kj=1w(i)jlog1(2π)n2|Σj|(12)exp(−12(x(i)−μj)TΣ−1j(x(i)−μj))ϕjw(i)j ∑ i = 1 m ∑ j = 1 k w j ( i ) l o g 1 ( 2 π ) n 2 | Σ j | ( 1 2 ) e x p ( − 1 2 ( x ( i ) − μ j ) T Σ j − 1 ( x ( i ) − μ j ) ) ϕ j w j ( i )

分别對其中的未知參數求偏導數:

  • 對均值求偏導

    ∇uj∑mi=1∑kj=1w(i)jlog1(2π)n2|Σj|(12)exp(−12(x(i)−μj)TΣ−1j(x(i)−μj))ϕjw(i)j ∇ u j ∑ i = 1 m ∑ j = 1 k w j ( i ) l o g 1 ( 2 π ) n 2 | Σ j | ( 1 2 ) e x p ( − 1 2 ( x ( i ) − μ j ) T Σ j − 1 ( x ( i ) − μ j ) ) ϕ j w j ( i )

=−∇uj∑mi=1∑kj=1w(i)j12(x(i)−μj)TΣ−1j(x(i)−μj) = − ∇ u j ∑ i = 1 m ∑ j = 1 k w j ( i ) 1 2 ( x ( i ) − μ j ) T Σ j − 1 ( x ( i ) − μ j )

=∑mi=1w(i)jΣ−1j(x(i)−μj)=0 = ∑ i = 1 m w j ( i ) Σ j − 1 ( x ( i ) − μ j ) = 0

可得

μj=∑mi=1w(i)jx(i)∑mi=1w(i)j μ j = ∑ i = 1 m w j ( i ) x ( i ) ∑ i = 1 m w j ( i )

  • 對方差求導:

    ∇Σj∑mi=1∑kj=1w(i)jlog1(2π)n2|Σj|(12)exp(−12(x(i)−μj)TΣ−1j(x(i)−μj))ϕjw(i)j ∇ Σ j ∑ i = 1 m ∑ j = 1 k w j ( i ) l o g 1 ( 2 π ) n 2 | Σ j | ( 1 2 ) e x p ( − 1 2 ( x ( i ) − μ j ) T Σ j − 1 ( x ( i ) − μ j ) ) ϕ j w j ( i )

=∇Σj∑mi=1∑kj=1w(i)j(logΣ−12−12(x(i)−μj)TΣ−1j(x(i)−μj)) = ∇ Σ j ∑ i = 1 m ∑ j = 1 k w j ( i ) ( l o g Σ − 1 2 − 1 2 ( x ( i ) − μ j ) T Σ j − 1 ( x ( i ) − μ j ) )

=∑mi=1w(i)jΣ(−1)−∑mi=1w(i)j(x(i)−μj)(x(i)−μj)TΣ(−2)=0 = ∑ i = 1 m w j ( i ) Σ ( − 1 ) − ∑ i = 1 m w j ( i ) ( x ( i ) − μ j ) ( x ( i ) − μ j ) T Σ ( − 2 ) = 0

可得

Σj=∑mi=1w(i)j(x(i)−μj)(x(i)−μj)T∑mi=1w(i)j Σ j = ∑ i = 1 m w j ( i ) ( x ( i ) − μ j ) ( x ( i ) − μ j ) T ∑ i = 1 m w j ( i )

  • 對 ϕ ϕ 求偏導, 等式限制,用到拉格朗日乘子法, 删除常數項目得到:

    ∇ϕj∑mi=1∑kj=1w(i)jlog(ϕj)+β(∑kj=1ϕj−1) ∇ ϕ j ∑ i = 1 m ∑ j = 1 k w j ( i ) l o g ( ϕ j ) + β ( ∑ j = 1 k ϕ j − 1 )

    =∑mi=1w(i)jϕj+β=0 = ∑ i = 1 m w j ( i ) ϕ j + β = 0

    −β=∑mi=1∑kj=1w(i)j=m − β = ∑ i = 1 m ∑ j = 1 k w j ( i ) = m

    可得

    ϕj=1m∑mi=1w(i)j ϕ j = 1 m ∑ i = 1 m w j ( i )

EM推導PLSA模型

詳細過程可參考作者的另一篇部落格plsaEM的詳細推導

繼續閱讀