計算機小白準備逐漸重溫機器學習算法,這是學習的第一個,高斯混合模型,文章系轉載,原文位址:http://blog.csdn.net/hjimce/article/details/45244603
——————————————————————————————華麗的分割線
高斯混合算法是EM算法的一個典型的應用,EM算法的推導過程這裡不打算詳解,直接講GMM算法的實作。之前做圖像分割grab cut 算法的時候,隻知道把opencv中的高斯混合模型代碼複制下來,然後封裝成類使用,學的比較淺。結果沒過幾天發現高斯混合算法又忘了差不多了,于是用matlab去親自寫過一遍,終于發現了高斯混合模型的奧義。我的了解是高斯混合模型其實是進化版的k均值算法,是以學習高斯混合模型,最好還是把k均值算法寫過一遍。高斯混合與k均值的本質差別在于權值問題,k均值采用的是均勻權值,而高斯混合的權值需要根據高斯模型的機率進行确定。
開始學習高斯混合模型,需要先簡單複習一下單高斯模型的參數估計方法,描述一個高斯模型其實就是要計算它的均值、協方差矩陣(一維空間為方差,二維以上稱之為協方差矩陣):
假設有資料集X={x1,x2,x3……,xn},那麼用這些資料來估計單高斯模型參數的計算公式為
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsICdzFWRoRXdvN1LclHdpZXYyd2LcBzNvwVZ2x2bzNXak9CX90TQNNkRrFlQKBTSvwFbslmZvwFMwQzLcVmepNHdu9mZvwFVywUNMZTY18CX052bm9CX9MGVNFTRU9UMZpWTmZEWjZXUYpVd1kmYr50MZV3YyI2cKJDT29GRjBjUIF2LcRHelR3LcJzLctmch1mclRXY39DM0cDM0QzM1ETOxQDM4EDMy8CX0Vmbu4GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
,
OK,開始寫代碼前,先用matlab生成資料集,然後在進行聚類:
利用matlab的生成高斯模型資料集X:
mu = [2 3];
SIGMA = [1 0; 0 2];
r1 = mvnrnd(mu,SIGMA,1000);
plot(r1(:,1),r1(:,2),'r+');
然後利用上面的估計方法計算均值,和協方差是否滿足均值為[2 3],協方差為[1 0; 0 2];測試代碼如下,r2、covmat即為計算結果
[m n]=size(r1);
center=sum(r1)./m;
r2(:,)=r1(:,)-center();
r2(:,)=r1(:,)-center();
covmat=/m*r2'*r2;
先把單高斯模型的函數寫好,因為高斯混合模型是它的進化版,計算高斯混合模型過程中需要調用單高斯模型參數估計,寫好代碼後面才不會亂掉。開始高斯混合模組化之前,我先用matlab生成一個測試資料集data,如下圖,然後再進行算法測試。
生成資料集代碼如下:
%生成測試資料
mu = [ ];%測試資料1
SIGMA = [ ; 0 2];
r1 = mvnrnd(mu,SIGMA,);
plot(r1(:,),r1(:,),'.');
hold on;
mu = [ ];%測試資料2
SIGMA = [ ; 0 2];
r2 = mvnrnd(mu,SIGMA,);
plot(r2(:,),r2(:,),'.');
mu = [ ];%測試資料3
SIGMA = [ ; 0 2];
r3= mvnrnd(mu,SIGMA,);
plot(r3(:,),r3(:,),'.');
data=[r1;r2;r3];
ok,資料生成完畢,接着我們正式開始高斯混合算法解析,先看一下高斯混合模型的模組化求參步驟:
高斯混合模型的求解,說得簡單一點就是要求解高斯模型中的均值與協方差,現在我們要把上述的資料分成3類,那麼我們就是要求解3個均值及其對應的3個協方差矩陣。先講一下總體步驟,高斯混合模型包含3個步驟:
a.初始化各個高斯模型的參數,及每個高斯模型的權重;
b.根據各個高斯模型的參數及其權重,計算每個點屬于各個高斯模型的權重,計算公式為:
其中:
,Wj是每個高斯模型在這個模型所占用得權重。這個公式說的簡單一點就是每個高斯模型的權重與其機率的乘積,這樣計算出來就相當于每個高斯模型在每個資料點中的所占用的比例。
c.更新各個高斯模型的均值與方差,計算公式如下:
d.更新各個高斯模型的總權重,計算公式如下:
其實第c、d兩個步驟,無所謂順序,你完全可以總權重更新放在各個模型參數更新之前。疊代過程就是b、c、d三個步驟進行更新就可以了。OK,接着結合上面的公式寫一寫代碼。
(1)初始化高斯模型參數。
這一步初始化,在實際應用中一般是先通過k均值算法進行初始聚類,然後根據聚類結果進行計算初始化參數。不過這裡我為了測試,我們選擇随機初始化,這樣才能看出GMM算法到底能不能實作聚類。
我這裡各個高斯模型初始均值(中心)的初始化方法選擇跟k均值的初始化方法一樣,也就是随機選擇k個點位置作為k個高斯模型的初始均值。然後協方差矩陣的初始化,我選擇機關矩陣,具體代碼如下:
[m n]=size(data);
kn=;
countflag=zeros(,kn);
tdata=cell(,kn);%建立3個空矩陣
mu=cell(,kn);%建立3個空矩陣
sigma=cell(,kn);%建立3個空矩陣
%方案2 随機初始化參數
for i=:kn
mu{,i}=data(i*,:);
sigma{,i}=eye(,);
weightp(i)=/kn;
end
(2)計算各個模型在各個點的權重值
這一步是計算每個資料點屬于各個高斯混合的機率,說白了就是計算權值:
pro_ij=zeros(m,kn);%存儲每個點屬于每個類的機率
for i=:m
sumpk=;
for j=:kn
pk(j)=weightp(j)*GSMPro(mu{,j},sigma{,j},data(i,:));
sumpk=sumpk+pk(j);
end
for j=:kn
pro_ij(i,j)=pk(j)/sumpk;
end
end
(3)步驟c 更新參數
for j=:kn
[mu{,j},sigma{,j}]=WeightGSM(data,pro_ij(:,j));
end
(4)步驟d 更新各個模型的總權重
for j=:kn
weightp(j)=sum(pro_ij(:,j))/m;
end
然後把步驟2、3、4的代碼放在循環語句中進行疊代就ok了。最後貼一下整份代碼:
1、腳本檔案:
close all;
clear;
clc;
%生成測試資料
mu = [ ];%測試資料1
SIGMA = [ ; ];
r1 = mvnrnd(mu,SIGMA,);
plot(r1(:,),r1(:,),'.');
hold on;
mu = [ ];%測試資料2
SIGMA = [ ; ];
r2 = mvnrnd(mu,SIGMA,);
plot(r2(:,),r2(:,),'.');
mu = [ ];%測試資料3
SIGMA = [ ; ];
r3= mvnrnd(mu,SIGMA,);
plot(r3(:,),r3(:,),'.');
data=[r1;r2;r3];
[m n]=size(data);
kn=;
countflag=zeros(,kn);
tdata=cell(,kn);%建立10個空矩陣
mu=cell(,kn);%建立10個空矩陣
sigma=cell(,kn);%建立10個空矩陣
% 方案1 初始化采用kmeans,做參數的初步估計
% Idx=kmeans(data,kn);
% figure(2);%繪制初始化結果
% hold on;
% for i=1:m
% if Idx(i)==1
% plot(data(i,1),data(i,2),'.y');
% elseif Idx(i)==2
% plot(data(i,1),data(i,2),'.b');
% end
% end
% for i=1:m
% tdata{1,Idx(i)}=[tdata{1,Idx(i)};data(i,:)];
% end
% for i=1:kn
% [mu{1,i},sigma{1,i}]=GSMData(tdata{1,i});
% end
% for i=1:kn
% [trow,tcol]=size(tdata{1,i});
% weightp(i)=trow/m;
% end
%方案2 随機初始化
for i=:kn
mu{,i}=data(i*,:);
sigma{,i}=eye(,);
weightp(i)=/kn;
end
it=;
while it<
%E步 計算每個點處于每個類的機率
pro_ij=zeros(m,kn);%存儲每個點屬于每個類的機率
for i=:m
sumpk=;
for j=:kn
pk(j)=weightp(j)*GSMPro(mu{,j},sigma{,j},data(i,:));
sumpk=sumpk+pk(j);
end
for j=:kn
pro_ij(i,j)=pk(j)/sumpk;
end
end
%M步
for j=:kn
[mu{,j},sigma{,j}]=WeightGSM(data,pro_ij(:,j));
end
%更新權值
for j=:kn
weightp(j)=sum(pro_ij(:,j))/m;
end
sumw=sum(weightp);
it=it+;
end
for i=:m
[value index]=max(pro_ij(i,:));
Idx(i)=index;
end
figure();
hold on;
for i=:m
if Idx(i)==
plot(data(i,),data(i,),'.y');
elseif Idx(i)==
plot(data(i,),data(i,),'.b');
elseif Idx(i)==
plot(data(i,),data(i,),'.r');
end
end
% figure(3);
% %px=gmmstd(data,3);
% for i=1:m
% [value index]=max(px(i,:));
% Idx(i)=index;
% end
% hold on;
% for i=1:m
% if Idx(i)==1
% plot(data(i,1),data(i,2),'.y');
% elseif Idx(i)==2
% plot(data(i,1),data(i,2),'.b');
% elseif Idx(i)==3
% plot(data(i,1),data(i,2),'.r');
% end
% end
%單高斯模型參數估計
% [m n]=size(r1);
% center=sum(r1)./m;
% r2(:,1)=r1(:,1)-center(1);
% r2(:,2)=r1(:,2)-center(2);
% covmat=1/m*r2'*r2;
2、相關函數
function [ mu ,sigma ] = WeightGSM(data,weight)
%計算權重均值
[m n]=size(data);
sumweight=sum(weight);
weightdata=[];
for i=:m
weightdata(i,:)=weight(i)*data(i,:);
end
center=sum(weightdata)/sumweight;
%計算權重協方差
for i=:n
r2(:,i)=data(:,i)-center(i);
end
for i=:m
r1(i,:)=weight(i)*r2(i,:);
end
sigma=/sumweight*r1'*r2;
mu=center;
end
function [pro] = GSMPro(mu ,sigma,x)
pro=exp(-*(x-mu)*inv(sigma)*(x-mu)');
pro=1/sqrt(2*pi*det(sigma))*pro;
end
看以下最後的測試結果: