
Self-Organizing Map (SOM) Neural Networks: Principles and a Worked MATLAB Example

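      This article introduces the principles of the SOM neural network and gives a MATLAB implementation built around a worked example, to help beginners get started. My own knowledge is limited, so corrections are welcome.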

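1. Introduction to SOM Neural Networks

     A self-organizing map (SOM) performs unsupervised clustering of data. The idea is simple: it is essentially a neural network with only an input layer and a hidden layer (the competitive layer). Each node in the hidden layer represents one cluster to be formed. Training uses competitive learning: each input sample finds the hidden-layer node that matches it best, called its activated node or "winning neuron". The parameters of the activated node are then updated by stochastic gradient descent. Nodes near the activated node are updated as well, by an amount that depends on their distance from it.

      Consequently, a defining feature of the SOM is that the hidden-layer nodes have a topology. We choose this topology ourselves: for a one-dimensional model the hidden nodes are connected in a line; for a two-dimensional topology they form a plane, as in the figure below (also known as a Kohonen network):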

[Figure: a two-dimensional SOM (Kohonen network)]

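      Because the hidden layer has a topology, we can also say that an SOM discretizes input of any dimension onto a one- or two-dimensional discrete space (higher dimensions are uncommon). The nodes of the computation layer are fully connected to the nodes of the input layer.

Once the topology is fixed, the computation proceeds in roughly the following steps:

1) Initialization: each node initializes its parameters randomly. Each node has as many parameters as the input has dimensions.

2) For each input sample, find the node that matches it best. If the input is D-dimensional, i.e. X = {x_i, i = 1, ..., D}, the discriminant function can be the (squared) Euclidean distance: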

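$$d_j(x) = \sum_{i=1}^{D} (x_i - w_{ji})^2$$

where w_ji is the i-th parameter of node j. The activated node I(x) is the node with the smallest d_j(x).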

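3) Having found the activated node I(x), we also want to update the nodes near it. Let S_ij denote the distance between nodes i and j. Each node near I(x) is assigned an update weight: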

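$$T_{j,I(x)} = \exp\left(-\frac{S_{j,I(x)}^2}{\delta^2}\right)$$

where δ sets the width of the neighborhood (the parameter `delta` in the code of Section 3, which uses this form).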

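In short, the update applied to a nearby node is discounted according to its distance from the activated node.

4) Next, update the node parameters by gradient descent: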

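$$\Delta w_{ji} = \alpha \, T_{j,I(x)} \, (x_i - w_{ji})$$

where α is the step size (the parameter `alpha` in the code of Section 3).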

Iterate until convergence.
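The four steps above can be sketched in a few lines of MATLAB. This is only a minimal illustration, not the article's program: it assumes a one-dimensional chain of nodes, random 2-D data, and made-up values for `D`, `K`, `N`, `alpha`, and `delta`. The neighborhood weight uses the same exp(-S^2/delta^2) form as the code in Section 3.

```matlab
% Minimal SOM sketch on a 1-D chain of K nodes.
% All sizes and rates below are illustrative assumptions.
D = 2; K = 5; N = 50;
X = rand(D, N);              % N random samples, one per column
W = rand(D, K);              % 1) random initialization, one weight vector per node
alpha = 0.5;                 % step size
delta = 1.0;                 % neighborhood width

for epoch = 1:20
    for n = 1:N
        x = X(:, n);
        % 2) competition: squared Euclidean distance from x to every node
        E = W - repmat(x, 1, K);
        d = sum(E.^2, 1);
        [~, winner] = min(d);
        % 3) cooperation: update weight from the chain distance to the winner
        S = abs((1:K) - winner);
        T = exp(-S.^2 / delta^2);
        % 4) adaptation: move every node toward x, scaled by alpha*T
        W = W - alpha * E .* repmat(T, D, 1);
    end
end
disp(W);
```

Because each update is a convex combination of the old weight and the sample, the weights stay inside the range of the data.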

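2. Problem Description

The 26 letters of the English alphabet serve as input samples for the SOM. Each character corresponds to a 5-dimensional vector; Table 4-2 lists the mapping. As Table 4-2 shows, the vectors for A, B, C, D, and E share four identical components (in the data listed in Section 3, their last four components are all 0), so A, B, C, D, and E should fall into one class. The vectors for F, G, H, I, and J share three identical components and, by the same reasoning, should also form one class, and so on. From these similarity relations among the input vectors in Table 4-2, the characters can be placed on the tree diagram shown in Figure 4-8. An SOM network is then used to cluster them.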
[Table 4-2: character-to-vector mapping; Figure 4-8: tree diagram of the characters]

3. MATLAB Implementation

SOM_main.m

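%%% Self-organizing network (SOM) exercise
%%% Author: xd.wp
%%% Date: 2016.10.02 19:16
%% Program notes:
%%%          1. The output layer is a two-dimensional plane.
%%%          2. The neighborhood and the weight adjustment use exp(-distant^2/delta^2).
%%%          3. Samples are 5-dimensional; the output layer has 70 nodes.
%%%          4. Input data are normalized to unit vectors.
clear;
clc;
%% Initialize the network and its parameters
% Load the data and normalize it
[train_data,train_label]=SOM_data_process();
data_num=size(train_data,2);

% Weight initialization
% weight_temp=ones(5,70)/1000;
weight_temp=rand(5,70)/1000;

% Number of nodes
node_num=size(weight_temp,2);

% Weight normalization (each column is scaled by its largest component)
weight=zeros(size(weight_temp));
for i=1:node_num
    weight(:,i)=weight_temp(:,i)/max(weight_temp(:,i));
end

% Neighborhood function parameter
delta=2;

% Step size
alpha=0.6;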
%% Kohonen learning
index_plot=zeros(1,data_num);                   % winning node of each sample, used for plotting
for t=4:-1:1                                    % epochs; the neighborhood radius is t-1
    index_active=ones(1,node_num);              % 1 = node still available in this epoch
    for n=1:data_num                            % present each sample in turn
        % Competition: the node at minimum distance wins
        [j_min]=SOM_compare(weight,train_data(:,n),node_num,index_active);
        
        % Deactivate the winner so that samples map to nodes one-to-one
        index_active(1,j_min)=0;
        
        % Record the winner for the plot below
        index_plot(1,n)=j_min;
        [x,y]=line_to_array(j_min);
        fprintf('Position [%d,%d] holds character %s \n',x,y,train_label(1,n));
        
        % Learning: adjust the network weights with a shrinking neighborhood
        switch t-1
            case 3
                [weight]=SOM_neighb3(weight,train_data(:,n),j_min,delta,alpha);
            case 2
                [weight]=SOM_neighb2(weight,train_data(:,n),j_min,delta,alpha);
            case 1
                [weight]=SOM_neighb1(weight,train_data(:,n),j_min,delta,alpha);
            otherwise
                [weight]=SOM_neighb0(weight,train_data(:,n),j_min,alpha);
        end
        
    end
end
%% Plot the node layout
figure(1);
axis([0,12,0,12]);
hold on;
for n=1:data_num
    [x,y]=line_to_array(index_plot(1,n));
    text(x,y,'*');                          % marker at the winning node
    text(x+0.2,y+0.2,train_label(1,n));     % character label next to it
end
           

SOM_data_process.m

function [train_data,train_label]=SOM_data_process()
train_data=[1 0 0 0 0;
            2 0 0 0 0;
            3 0 0 0 0;
            4 0 0 0 0;
            5 0 0 0 0;
            3 1 0 0 0;
            3 2 0 0 0;
            3 3 0 0 0;
            3 4 0 0 0;
            3 5 0 0 0;
            3 3 1 0 0;
            3 3 2 0 0;
            3 3 3 0 0;
            3 3 4 0 0;
            3 3 5 0 0;
            3 3 3 1 0;
            3 3 3 2 0;
            3 3 3 3 0;
            3 3 3 4 0;
            3 3 3 5 0;
            3 3 3 3 1;
            3 3 3 3 2;
            3 3 3 3 3;
            3 3 3 3 4;
            3 3 3 3 5;
            3 3 3 3 6];
train_label='ABCDEFGHIJKLMNOPQRSTUVWXYZ';   % one label per row of train_data
train_data=train_data';                     % one sample per column (5x26)
num_samples=size(train_data,2);             % avoids shadowing the built-in length()
for i=1:num_samples
     % scale each sample to unit Euclidean length
     train_data(:,i)=train_data(:,i)/sqrt(sum(train_data(:,i).*train_data(:,i)));
% train_data(:,i)=train_data(:,i)/max(train_data(:,i));
end
end
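
The per-column loop above can also be written without a loop. The sketch below is an illustrative alternative, not part of the original program; it uses three rows from the table (A, K, and Z) and the hypothetical names `X`, `norms`, and `X_unit`.

```matlab
% Column-wise unit-length normalization without an explicit loop.
X = [1 0 0 0 0; 3 3 1 0 0; 3 3 3 3 6]';      % rows for A, K, Z, transposed to 5x3
norms = sqrt(sum(X.^2, 1));                   % 1x3 Euclidean norm of each column
X_unit = X ./ repmat(norms, size(X,1), 1);    % divide every column by its own norm
disp(sqrt(sum(X_unit.^2, 1)));                % each entry is 1 up to rounding
```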
           

SOM_compare.m

function [j_min]=SOM_compare(weight,train_data_active,node_num,index_active)
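distant=zeros(node_num,1);
for j=1:node_num
    distant(j,1)=sum((weight(:,j)-train_data_active).^2);   % squared Euclidean distance
end
[~,j_min]=min(distant);
% Skip nodes already claimed by another sample in this epoch
while(index_active(1,j_min)==0)
    distant(j_min,1)=Inf;
    [~,j_min]=min(distant);
end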

end
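
The same winner search can be expressed with a mask instead of a `while` loop. The following is a small sketch on made-up numbers (a 2-D input and four nodes, all hypothetical), meant only to show the masking idea.

```matlab
% Winner search with already-claimed nodes masked out (illustrative 2x4 example).
weight = [0 1 2 3; 0 1 2 3];        % 2-D weights for 4 nodes, one per column
x = [1.2; 0.9];                     % input sample
index_active = [1 0 1 1];           % node 2 has already been claimed

d = sum((weight - repmat(x, 1, size(weight,2))).^2, 1);  % squared distances
d(index_active == 0) = Inf;         % claimed nodes can never win
[~, j_min] = min(d);
disp(j_min);
```

Node 2 is the closest node here, but because it is masked the search falls through to the next-closest available node.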
           

SOM_neighb3.m

function [weight]=SOM_neighb3(weight,train_data_active,j_min,delta,alpha)
% Update the winner and every node within 3 grid steps of it
% (a 7x7 window clipped at the edges of the 7x10 output grid).
% Update strength: alpha*exp(-distant^2/delta^2), largest at the winner.

%% Coordinate conversion
% Convert the linear index in the 1x70 vector to (row, column) in the 7x10 grid:
%    1   8    ...
%    7   14   ...
[x,y]=line_to_array(j_min);

%% Weight adjustment
radius=3;
for m=max(1,x-radius):min(7,x+radius)          % rows, clipped at top and bottom
    for n=max(1,y-radius):min(10,y+radius)     % columns, clipped at left and right
        distant=sqrt((x-m)^2+(y-n)^2);
        k=(n-1)*7+m;                           % linear index of node (m,n)
        weight(:,k)=weight(:,k)-alpha*exp(-distant^2/delta^2)*(weight(:,k)-train_data_active);
    end
end
end
           

SOM_neighb2.m

function [weight]=SOM_neighb2(weight,train_data_active,j_min,delta,alpha)
% Update the winner and every node within 2 grid steps of it
% (a 5x5 window clipped at the edges of the 7x10 output grid).
% Update strength: alpha*exp(-distant^2/delta^2), largest at the winner.

%% Coordinate conversion
% Convert the linear index in the 1x70 vector to (row, column) in the 7x10 grid:
%    1   8    ...
%    7   14   ...
[x,y]=line_to_array(j_min);

%% Weight adjustment
radius=2;
for m=max(1,x-radius):min(7,x+radius)          % rows, clipped at top and bottom
    for n=max(1,y-radius):min(10,y+radius)     % columns, clipped at left and right
        distant=sqrt((x-m)^2+(y-n)^2);
        k=(n-1)*7+m;                           % linear index of node (m,n)
        weight(:,k)=weight(:,k)-alpha*exp(-distant^2/delta^2)*(weight(:,k)-train_data_active);
    end
end
end
           

SOM_neighb1.m

function [weight]=SOM_neighb1(weight,train_data_active,j_min,delta,alpha)
% Update the winner and every node within 1 grid step of it
% (a 3x3 window clipped at the edges of the 7x10 output grid).
% Update strength: alpha*exp(-distant^2/delta^2), largest at the winner.

%% Coordinate conversion
% Convert the linear index in the 1x70 vector to (row, column) in the 7x10 grid:
%    1   8    ...
%    7   14   ...
[x,y]=line_to_array(j_min);

%% Weight adjustment
radius=1;
for m=max(1,x-radius):min(7,x+radius)          % rows, clipped at top and bottom
    for n=max(1,y-radius):min(10,y+radius)     % columns, clipped at left and right
        distant=sqrt((x-m)^2+(y-n)^2);
        k=(n-1)*7+m;                           % linear index of node (m,n)
        weight(:,k)=weight(:,k)-alpha*exp(-distant^2/delta^2)*(weight(:,k)-train_data_active);
    end
end
end
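
The three neighborhood functions differ only in their radius, so one update with a `radius` parameter could in principle replace them. The sketch below is an assumption-laden illustration rather than the author's code: it starts from all-zero weights, uses a sample of all ones, and places the winner at node 1 so that the result is easy to check by hand.

```matlab
% One neighborhood update on a 7x10 grid for an arbitrary radius (illustrative values).
rows = 7; cols = 10; radius = 2;
alpha = 0.6; delta = 2;
weight = zeros(5, rows*cols);            % all weights start at the origin for clarity
x_in = ones(5, 1);                       % sample to learn
j_min = 1;                               % winner: node at row 1, column 1

[wx, wy] = ind2sub([rows cols], j_min);  % winner's row and column
[gy, gx] = meshgrid(1:cols, 1:rows);     % gx(k), gy(k): row and column of node k
inside = abs(gx - wx) <= radius & abs(gy - wy) <= radius;   % square window around the winner
h = alpha * exp(-((gx - wx).^2 + (gy - wy).^2) / delta^2) .* inside;
h = h(:)';                               % 1x70, column-major like the node numbering

weight = weight - (weight - repmat(x_in, 1, rows*cols)) .* repmat(h, 5, 1);
disp(nnz(h));                            % number of nodes that were updated
```

With the winner in a corner, the window is clipped to 3 rows by 3 columns, and the winner itself moves a fraction `alpha` of the way toward the sample.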
           

SOM_neighb0.m

function [weight]=SOM_neighb0(weight,train_data_active,j_min,alpha)
% Radius 0: only the winner is updated, moving toward the sample
weight(:,j_min)=weight(:,j_min)-alpha*(weight(:,j_min)-train_data_active);
end
           

line_to_array.m

function [x,y]=line_to_array(j_min)
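% Convert a linear index in the 1x70 vector to (row, column) in the 7x10 grid:
%    1   8    ...
%    7   14   ...
y=ceil(j_min/7);
x=rem(j_min-1,7)+1;   % row in 1..7; plain rem(j_min,7) would give 0 for multiples of 7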
end
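
As a quick sanity check of the index conversion, the sketch below compares a shifted-remainder formula against MATLAB's built-in `ind2sub` for all 70 nodes. The variable names are illustrative. Note that a plain `rem(j,7)` yields 0 for multiples of 7, which is why the shifted form `rem(j-1,7)+1` is used here.

```matlab
% Check that index -> (row, column) conversion on the 7x10 grid matches ind2sub.
rows = 7; cols = 10;
j = 1:rows*cols;
x = rem(j - 1, rows) + 1;        % row, always in 1..7
y = ceil(j / rows);              % column, in 1..10
[r, c] = ind2sub([rows cols], j);
disp(isequal(x, r) && isequal(y, c));
```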
           

4. Results

Results for different initial conditions:

[Figures: node layout plots for different initial conditions]
