EM算法
一、EM算法簡介:
EM 算法是Dempster,Laind,Rubin于1977年提出的求參數極大似然估計的一種方法,它可以從非完整資料集中對參數進行MLE估計,是一種非常簡單實用的學習算法。這種方法可以廣泛地應用于處理缺損資料、截尾資料以及帶有噪聲等所謂的不完全資料。具體地說,我們可以利用EM算法來填充樣本中的缺失資料、發現隐藏變量的值、估計HMM中的參數、估計有限混合分布中的參數以及可以進行無監督聚類等等。
最大期望算法(Expectation Maximization Algorithm,又譯為:期望最大化算法),是一種疊代算法,用于含有隐變量(hidden variable)的機率參數模型的最大似然估計或極大後驗機率估計。
在統計計算中,最大期望(EM)算法是在機率(probabilistic)模型中尋找參數最大似然估計或者最大後驗估計的算法,其中機率模型依賴于無法觀測的隐藏變量(Latent Variable)。最大期望經常用在機器學習和計算機視覺的資料聚類(Data Clustering)領域。
最大期望算法經過兩個步驟交替進行計算,第一步是計算期望(E),也就是将隐藏變量象能夠觀測到的一樣包含在内進而計算最大似然的期望值;另外一步是最大化(M),也就是最大化在 E 步上找到的最大似然的期望值進而計算參數的最大似然估計。M 步上找到的參數然後用于另外一個 E 步計算,這個過程不斷交替進行。
二、基本步驟(具體公式及推導,讀者參考其他文獻):
-
參數初始化
對需要估計的參數進行初始指派,包括均值、方差、混合系數以及期望。
-
E-Step計算
利用機率分布公式計算後驗機率,即期望。
-
M-step計算
重新估計參數,包括均值、方差、混合系數并且估計此參數下的期望值。
-
收斂性判斷
将新的與舊的值進行比較,并與設定的門檻值進行對比,判斷疊代是否結束,若不符合條件,則傳回到第2步,重新進行計算,直到收斂符合條件結束。
三、用python實作EM算法過程,下面以投硬币為例,供讀者研究:
TEST=[[5,5],[9,1],[8,2],[4,6],[7,3]];
#投出來的結果,前面是正面向上的次數,每組結果後面數字表示反面向上的次數。
#由于每次投币要不選擇A或者B,且僅從單個樣本資料,無法獲知,EM算法的主要目标是:通過大量計算和統計,将資料分離或尋找隐含變量。
print(TEST)#整個考慮可以從正面入手,反面配合。
def P(sA,sB,t1,t2):
#s是初始的機率,t1取正面的個數;t2取反面的個數,以機率密度為準。
PA=Cmn(t1+t2,t1)*(sA**t1)*((1-sA)**t2) #計算A出現的機率。
PB=Cmn(t1+t2,t1)*(sB**t1)*((1-sB)**t2) #計算B出現的機率。
return round(PA/(PA+PB),2) #求出新的機率值,然後根據這個機率值進行後面期望值的計算;且保留2位有效數字。
def fac(n):#階乘。
f=1
fory in range(2,n+1):
f=f*y
return f
def Cmn(m,n):#先定義cmn後面利用這個機率密度函數,會使用排列組合關系。當Cm,n=Cm,m-n。減少計算量。
s=m-n
ifs<n:
n=s
f=1;t=m
fory in range(0,n):
f=f*t
t=t-1
return f/fac(n)
def CoinAB(oldoA,oldoB): #計算期望值,通過一次期望值的求解,再重新疊代機率值。
UA1=0;UA2=0; #UA1和UA2是A硬币投出的期望值。
t3=0;t4=0;t5=0;t6=0; #t3和t4、t5和t6都是為計算硬币A和B的期望值。
fory in TEST: #周遊所有樣本資料。
UA1,UA2=y #取目前值,前面表示正面,後面表示反面。
oA=P(oldoA,oldoB,UA1,UA2) #計算出出A硬币的機率。
print(oA,1-oA)
t3=t3+UA1*round(oA,2) #計算A期望值(針對正面這個事實開始讨論)
t4=t4+UA2*round(oA,2)
t5=t5+UA1*round((1-oA),2) #計算B期望值(針對正面這個事實開始讨論)
t6=t6+UA2*round((1-oA),2)
return round(t3/(t3+t4),2),round(t5/(t5+t6),2) #傳回疊代新一輪的A、B的機率。
def EM(oA,oB):
y=0; #計疊代次數。
while(1):
y=y+1
oldoA=oA;oldoB=oB #先存儲疊代資料,為了計算收斂值,結束條件。
oA,oB=CoinAB(oldoA,oldoB) #分别指派,為了下次使用。
print("----y={},oA={},oB={}".format(y,oA,oB))
if (oldoA-oA)**2+(oldoB-oB)**2<0.005:#自己設定收斂條件,目的為終止循環。
break
print("oA=",oA,"oB=",oB)
#oA,oB=eval(input("請輸入初始值-oA,oB,逗号隔開:\n"))
oA=0.6
oB=0.4
EM(oA,oB)
大家,加油!