天天看點

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

From:http://blog.csdn.net/yangliuy/article/details/8330640

            http://blog.csdn.net/yangliuy/article/details/8302599

            http://blog.csdn.net/yangliuy/article/details/8457329

*********************************************************************************************************************************

第一篇 PLSA及EM算法

[本文PDF版本下載下傳位址 PLSA及EM算法-yangliuy]

前言:本文主要介紹PLSA及EM算法,首先給出LSA(隐性語義分析)的早期方法SVD,然後引入基于機率的PLSA模型,其參數學習采用EM算法。接着我們分析如何運用EM算法估計一個簡單的mixture unigram 語言模型和混合高斯模型GMM的參數,最後總結EM算法的一般形式及運用關鍵點。對于改進PLSA,引入hyperparameter的LDA模型及其Gibbs Sampling參數估計方法放在本系列後面的文章LDA及Gibbs Samping介紹。

1 LSA and SVD

LSA(隐性語義分析)的目的是要從文本中發現隐含的語義次元-即“Topic”或者“Concept”。我們知道,在文檔的空間向量模型(VSM)中,文檔被表示成由特征詞出現機率組成的多元向量,這種方法的好處是可以将query和文檔轉化成同一空間下的向量計算相似度,可以對不同詞項賦予不同的權重,在文字檢索、分類、聚類問題中都得到了廣泛應用,在基于貝葉斯算法及KNN算法的newsgroup18828文本分類器的JAVA實作和基于Kmeans算法、MBSAS算法及DBSCAN算法的newsgroup18828文本聚類器的JAVA實作系列文章中的分類聚類算法大多都是采用向量空間模型。然而,向量空間模型沒有能力處理一詞多義和一義多詞問題,例如同義詞也分别被表示成獨立的一維,計算向量的餘弦相似度時會低估使用者期望的相似度;而某個詞項有多個詞義時,始終對應同一次元,是以計算的結果會高估使用者期望的相似度。

LSA方法的引入就可以減輕類似的問題。基于SVD分解,我們可以構造一個原始向量矩陣的一個低秩逼近矩陣,具體的做法是将詞項文檔矩陣做SVD分解

幾種機率語言模型和參數學習方法

  其中

幾種機率語言模型和參數學習方法

是以詞項(terms)為行, 文檔(documents)為列做一個大矩陣. 設一共有t行d列,  矩陣的元素為詞項的tf-idf值。然後把

幾種機率語言模型和參數學習方法

的r個對角元素的前k個保留(最大的k個保留), 後面最小的r-k個奇異值置0, 得到

幾種機率語言模型和參數學習方法

;最後計算一個近似的分解矩陣

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

在最小二乘意義下是

幾種機率語言模型和參數學習方法

的最佳逼近。由于

幾種機率語言模型和參數學習方法

最多包含k個非零元素,是以

幾種機率語言模型和參數學習方法

的秩不超過k。通過在SVD分解近似,我們将原始的向量轉化成一個低維隐含語義空間中,起到了特征降維的作用。每個奇異值對應的是每個“語義”次元的權重,将不太重要的權重置為0,隻保留最重要的次元資訊,去掉一些資訊“nosie”,因而可以得到文檔的一種更優表示形式。将SVD分解降維應用到文檔聚類的JAVA實作可參見此文。

IIR中給出的一個SVD降維的執行個體如下:

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

左邊是原始矩陣的SVD分解,右邊是隻保留權重最大2維,将原始矩陣降到2維後的情況。

2 PLSA

盡管基于SVD的LSA取得了一定的成功,但是其缺乏嚴謹的數理統計基礎,而且SVD分解非常耗時。Hofmann在SIGIR'99上提出了基于機率統計的PLSA模型,并且用EM算法學習模型參數。PLSA的機率圖模型如下

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

其中D代表文檔,Z代表隐含類别或者主題,W為觀察到的單詞,

幾種機率語言模型和參數學習方法

表示單詞出現在文檔

幾種機率語言模型和參數學習方法

的機率,

幾種機率語言模型和參數學習方法

表示文檔

幾種機率語言模型和參數學習方法

中出現主題

幾種機率語言模型和參數學習方法

下的單詞的機率,

幾種機率語言模型和參數學習方法

給定主題

幾種機率語言模型和參數學習方法

出現單詞

幾種機率語言模型和參數學習方法

的機率。并且每個主題在所有詞項上服從Multinomial 分布,每個文檔在所有主題上服從Multinomial 分布。整個文檔的生成過程是這樣的:

(1) 以

幾種機率語言模型和參數學習方法

的機率選中文檔

幾種機率語言模型和參數學習方法

(2) 以

幾種機率語言模型和參數學習方法

的機率選中主題

幾種機率語言模型和參數學習方法

(3) 以

幾種機率語言模型和參數學習方法

的機率産生一個單詞。

我們可以觀察到的資料就是

幾種機率語言模型和參數學習方法

對,而

幾種機率語言模型和參數學習方法

是隐含變量。

幾種機率語言模型和參數學習方法

的聯合分布為

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

分布對應了兩組Multinomial 分布,我們需要估計這兩組分布的參數。下面給出用EM算法估計PLSA參數的詳細推導過程。

3 Estimate parameters in PLSA  by EM

(注:本部分主要參考Tomas Hoffman, Unsupervised Learning by Probabilistic Latent Semantic Analysis.)

如文本語言模型的參數估計-最大似然估計、MAP及貝葉斯估計一文所述,常用的參數估計方法有MLE、MAP、貝葉斯估計等等。但是在PLSA中,如果我們試圖直接用MLE來估計參數,就會得到似然函數

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

其中

幾種機率語言模型和參數學習方法

是term 

幾種機率語言模型和參數學習方法

出現在文檔

幾種機率語言模型和參數學習方法

中的次數。n(di)表示文檔di中的總詞數。注意這是一個關于

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

的函數,一共有N*K + M*K個自變量(注意這裡M表示term的總數,一般文獻習慣用V表示),如果直接對這些自變量求偏導數,我們會發現由于自變量包含在對數和中,這個方程的求解很困難。是以對于這樣的包含“隐含變量”或者“缺失資料”的機率模型參數估計問題,我們采用EM算法。

EM算法的步驟是:

(1)E步驟:求隐含變量Given目前估計的參數條件下的後驗機率。

(2)M步驟:最大化Complete data對數似然函數的期望,此時我們使用E步驟裡計算的隐含變量的後驗機率,得到新的參數值。

兩步疊代進行直到收斂。

先解釋一下什麼是Incomplete data和complete data。Zhai老師在一篇經典的EM算法Notes中講到,當原始資料的似然函數很複雜時,我們通過增加一些隐含變量來增強我們的資料,得到“complete data”,而“complete data”的似然函數更加簡單,友善求極大值。于是,原始的資料就成了“incomplete data”。我們将會看到,我們可以通過最大化“complete data”似然函數的期望來最大化"incomplete data"的似然函數,以便得到求似然函數最大值更為簡單的計算途徑。

針對我們PLSA參數估計問題,在E步驟中,直接使用貝葉斯公式計算隐含變量在目前參數取值條件下的後驗機率,有

幾種機率語言模型和參數學習方法

在這個步驟中,我們假定所有的

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

都是已知的,因為初始時随機指派,後面疊代的過程中取前一輪M步驟中得到的參數值。

在M步驟中,我們最大化Complete data對數似然函數的期望。在PLSA中,Incomplete data 是觀察到的

幾種機率語言模型和參數學習方法

,隐含變量是主題

幾種機率語言模型和參數學習方法

,那麼complete data就是三元組

幾種機率語言模型和參數學習方法

,其期望是

幾種機率語言模型和參數學習方法

注意這裡

幾種機率語言模型和參數學習方法

是已知的,取的是前面E步驟裡面的估計值。下面我們來最大化期望,這又是一個多元函數求極值的問題,可以用拉格朗日乘數法。拉格朗日乘數法可以把條件極值問題轉化為無條件極值問題,在PLSA中目标函數就是

幾種機率語言模型和參數學習方法

,限制條件是

幾種機率語言模型和參數學習方法

由此我們可以寫出拉格朗日函數

幾種機率語言模型和參數學習方法

這是一個關于

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

的函數,分别對其求偏導數,我們可以得到

幾種機率語言模型和參數學習方法

注意這裡進行過方程兩邊同時乘以

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

的變形,聯立上面4組方程,我們就可以解出M步驟中通過最大化期望估計出的新的參數值

幾種機率語言模型和參數學習方法

解方程組的關鍵在于先求出

幾種機率語言模型和參數學習方法

,其實隻需要做一個加和運算就可以把

幾種機率語言模型和參數學習方法

的系數都化成1,後面就好計算了。

然後使用更新後的參數值,我們又進入E步驟,計算隐含變量

幾種機率語言模型和參數學習方法

 Given目前估計的參數條件下的後驗機率。如此不斷疊代,直到滿足終止條件。

注意到我們在M步驟中還是使用對Complete Data的MLE,那麼如果我們想加入一些先驗知識進入我們的模型,我們可以在M步驟中使用MAP估計。正如文本語言模型的參數估計-最大似然估計、MAP及貝葉斯估計中投硬币的二項分布實驗中我們加入“硬币一般是兩面均勻的”這個先驗一樣。而由此計算出的參數的估計值會在分子分母中多出關于先驗參數的preduo counts,其他步驟都是一樣的。具體可以參考Mei Qiaozhu 的Notes。

 ——————————————————

EM算法解法:

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

——————————————————

PLSA的實作也不難,網上有很多實作code。

例如這個PLSA的EM算法實作 http://ezcodesample.com/plsaidiots/PLSAjava.txt

主要的類如下(作者Andrew Polar)

[java]   view plain copy

  1. //The code is taken from:  
  2. //http://code.google.com/p/mltool4j/source/browse/trunk/src/edu/thu/mltool4j/topicmodel/plsa  
  3. //I noticed some difference with original Hofmann concept in computation of P(z). It is   
  4. //always even and actually not involved, that makes this algorithm non-negative matrix   
  5. //factoring and not PLSA.  
  6. //Found and tested by Andrew Polar.   
  7. //My version can be found on semanticsearchart.com or ezcodesample.com  

[java]   view plain copy

  1. class ProbabilisticLSA  
  2. {  
  3.     private Dataset dataset = null;  
  4.     private Posting[][] invertedIndex = null;  
  5.     private int M = -1; // number of data  
  6.     private int V = -1; // number of words  
  7.     private int K = -1; // number of topics  
  8.     public ProbabilisticLSA()  
  9.     {  
  10.     }  
  11.     public boolean doPLSA(String datafileName, int ntopics, int iters)  
  12.     {  
  13.         File datafile = new File(datafileName);  
  14.         if (datafile.exists())  
  15.         {  
  16.             if ((this.dataset = new Dataset(datafile)) == null)  
  17.             {  
  18.                 System.out.println("doPLSA, dataset == null");  
  19.                 return false;  
  20.             }  
  21.             this.M = this.dataset.size();  
  22.             this.V = this.dataset.getFeatureNum();  
  23.             this.K = ntopics;  
  24.              //build inverted index  
  25.             this.buildInvertedIndex(this.dataset);  
  26.             //run EM algorithm  
  27.             this.EM(iters);  
  28.             return true;  
  29.         }  
  30.         else  
  31.         {  
  32.             System.out.println("ProbabilisticLSA(String datafileName), datafile: " + datafileName + " doesn't exist");  
  33.             return false;  
  34.         }  
  35.     }  
  36.     //Build the inverted index for M-step fast calculation. Format:  
  37.     //invertedIndex[w][]: a unsorted list of document and position which word w  
  38.     // occurs.   
  39.     //@param ds the dataset which to be analysis  
  40.     @SuppressWarnings("unchecked")  
  41.     private boolean buildInvertedIndex(Dataset ds)  
  42.     {  
  43.         ArrayList<Posting>[] list = new ArrayList[this.V];  
  44.         for (int k=0; k<this.V; ++k) {  
  45.             list[k] = new ArrayList<Posting>();  
  46.         }  
  47.         for (int m = 0; m < this.M; m++)  
  48.         {  
  49.             Data d = ds.getDataAt(m);  
  50.             for (int position = 0; position < d.size(); position++)  
  51.             {  
  52.                 int w = d.getFeatureAt(position).dim;  
  53.                 // add posting  
  54.                 list[w].add(new Posting(m, position));  
  55.             }  
  56.         }  
  57.         // convert to array  
  58.         this.invertedIndex = new Posting[this.V][];  
  59.         for (int w = 0; w < this.V; w++)  
  60.         {  
  61.             this.invertedIndex[w] = list[w].toArray(new Posting[0]);  
  62.         }  
  63.         return true;  
  64.     }  
  65.     private boolean EM(int iters)  
  66.     {  
  67.         // p(z), size: K  
  68.         double[] Pz = new double[this.K];  
  69.         // p(d|z), size: K x M  
  70.         double[][] Pd_z = new double[this.K][this.M];  
  71.         // p(w|z), size: K x V  
  72.         double[][] Pw_z = new double[this.K][this.V];  
  73.         // p(z|d,w), size: K x M x doc.size()  
  74.         double[][][] Pz_dw = new double[this.K][this.M][];  
  75.          // L: log-likelihood value  
  76.          double L = -1;  
  77.          // run EM algorithm  
  78.          this.init(Pz, Pd_z, Pw_z, Pz_dw);  
  79.          for (int it = 0; it < iters; it++)  
  80.          {  
  81.              // E-step  
  82.              if (!this.Estep(Pz, Pd_z, Pw_z, Pz_dw))  
  83.              {  
  84.                  System.out.println("EM,  in E-step");  
  85.              }  
  86.              // M-step  
  87.              if (!this.Mstep(Pz_dw, Pw_z, Pd_z, Pz))  
  88.              {  
  89.                  System.out.println("EM, in M-step");  
  90.              }  
  91.              L = calcLoglikelihood(Pz, Pd_z, Pw_z);  
  92.              System.out.println("[" + it + "]" + "\tlikelihood: " + L);  
  93.          }  
  94.          //print result  
  95.          for (int m = 0; m < this.M; m++)  
  96.          {  
  97.              double norm = 0.0;  
  98.              for (int z = 0; z < this.K; z++) {  
  99.                  norm += Pd_z[z][m];  
  100.              }  
  101.              if (norm <= 0.0) norm = 1.0;  
  102.              for (int z = 0; z < this.K; z++) {  
  103.                  System.out.format("%10.4f", Pd_z[z][m]/norm);  
  104.              }  
  105.              System.out.println();  
  106.         }   
  107.         return false;  
  108.     }  
  109.     private boolean init(double[] Pz, double[][] Pd_z, double[][] Pw_z, double[][][] Pz_dw)  
  110.     {  
  111.         // p(z), size: K  
  112.         double zvalue = (double) 1 / (double) this.K;  
  113.         for (int z = 0; z < this.K; z++)  
  114.         {  
  115.             Pz[z] = zvalue;  
  116.         }  
  117.         // p(d|z), size: K x M  
  118.         for (int z = 0; z < this.K; z++)  
  119.         {  
  120.             double norm = 0.0;  
  121.             for (int m = 0; m < this.M; m++)  
  122.             {  
  123.                 Pd_z[z][m] = Math.random();  
  124.                 norm += Pd_z[z][m];  
  125.             }  
  126.             for (int m = 0; m < this.M; m++)  
  127.             {  
  128.                 Pd_z[z][m] /= norm;  
  129.             }  
  130.         }  
  131.         // p(w|z), size: K x V  
  132.         for (int z = 0; z < this.K; z++)  
  133.         {  
  134.             double norm = 0.0;  
  135.             for (int w = 0; w < this.V; w++)  
  136.             {  
  137.                 Pw_z[z][w] = Math.random();  
  138.                 norm += Pw_z[z][w];  
  139.             }  
  140.             for (int w = 0; w < this.V; w++)  
  141.             {  
  142.                 Pw_z[z][w] /= norm;  
  143.             }  
  144.         }  
  145.         // p(z|d,w), size: K x M x doc.size()  
  146.         for (int z = 0; z < this.K; z++)  
  147.         {  
  148.             for (int m = 0; m < this.M; m++)  
  149.             {  
  150.                 Pz_dw[z][m] = new double[this.dataset.getDataAt(m).size()];  
  151.             }  
  152.         }  
  153.         return false;  
  154.     }  
  155.     private boolean Estep(double[] Pz, double[][] Pd_z, double[][] Pw_z, double[][][] Pz_dw)  
  156.     {  
  157.         for (int m = 0; m < this.M; m++)  
  158.         {  
  159.             Data data = this.dataset.getDataAt(m);  
  160.             for (int position = 0; position < data.size(); position++)  
  161.             {  
  162.                 // get word(dimension) at current position of document m  
  163.                 int w = data.getFeatureAt(position).dim;  
  164.                 double norm = 0.0;  
  165.                 for (int z = 0; z < this.K; z++)  
  166.                 {  
  167.                     double val = Pz[z] * Pd_z[z][m] * Pw_z[z][w];  
  168.                     Pz_dw[z][m][position] = val;  
  169.                     norm += val;  
  170.                 }  
  171.                 // normalization  
  172.                 for (int z = 0; z < this.K; z++)  
  173.                 {  
  174.                     Pz_dw[z][m][position] /= norm;  
  175.                 }  
  176.             }  
  177.         }  
  178.         return true;  
  179.     }  
  180.     private boolean Mstep(double[][][] Pz_dw, double[][] Pw_z, double[][] Pd_z, double[] Pz)  
  181.     {  
  182.         // p(w|z)  
  183.         for (int z = 0; z < this.K; z++)  
  184.         {  
  185.             double norm = 0.0;  
  186.             for (int w = 0; w < this.V; w++)  
  187.             {  
  188.                 double sum = 0.0;  
  189.                 Posting[] postings = this.invertedIndex[w];  
  190.                 for (Posting posting : postings)  
  191.                 {  
  192.                     int m = posting.docID;  
  193.                     int position = posting.pos;  
  194.                     double n = this.dataset.getDataAt(m).getFeatureAt(position).weight;  
  195.                     sum += n * Pz_dw[z][m][position];  
  196.                 }  
  197.                 Pw_z[z][w] = sum;  
  198.                 norm += sum;  
  199.             }  
  200.             // normalization  
  201.             for (int w = 0; w < this.V; w++)  
  202.             {  
  203.                 Pw_z[z][w] /= norm;  
  204.             }  
  205.         }  
  206.         // p(d|z)  
  207.         for (int z = 0; z < this.K; z++)  
  208.         {  
  209.             double norm = 0.0;  
  210.             for (int m = 0; m < this.M; m++)  
  211.             {  
  212.                 double sum = 0.0;  
  213.                 Data d = this.dataset.getDataAt(m);  
  214.                 for (int position = 0; position < d.size(); position++)  
  215.                 {  
  216.                     double n = d.getFeatureAt(position).weight;  
  217.                     sum += n * Pz_dw[z][m][position];  
  218.                 }  
  219.                 Pd_z[z][m] = sum;  
  220.                 norm += sum;  
  221.             }  
  222.             // normalization  
  223.             for (int m = 0; m < this.M; m++)  
  224.             {  
  225.                 Pd_z[z][m] /= norm;  
  226.             }  
  227.         }  
  228.         //This is definitely a bug  
  229.         //p(z) values are even, but they should not be even  
  230.         double norm = 0.0;  
  231.         for (int z = 0; z < this.K; z++)  
  232.         {  
  233.             double sum = 0.0;  
  234.             for (int m = 0; m < this.M; m++)  
  235.             {  
  236.                 sum += Pd_z[z][m];  
  237.             }  
  238.             Pz[z] = sum;  
  239.             norm += sum;  
  240.        }  
  241.         // normalization  
  242.         for (int z = 0; z < this.K; z++)  
  243.         {  
  244.             Pz[z] /= norm;  
  245.             //System.out.format("%10.4f", Pz[z]);  //here you can print to see  
  246.         }  
  247.         //System.out.println();  
  248.         return true;  
  249.     }  
  250.     private double calcLoglikelihood(double[] Pz, double[][] Pd_z, double[][] Pw_z)  
  251.     {  
  252.         double L = 0.0;  
  253.         for (int m = 0; m < this.M; m++)  
  254.         {  
  255.             Data d = this.dataset.getDataAt(m);  
  256.             for (int position = 0; position < d.size(); position++)  
  257.             {  
  258.                 Feature f = d.getFeatureAt(position);  
  259.                 int w = f.dim;  
  260.                 double n = f.weight;  
  261.                 double sum = 0.0;  
  262.                 for (int z = 0; z < this.K; z++)  
  263.                 {  
  264.                     sum += Pz[z] * Pd_z[z][m] * Pw_z[z][w];  
  265.                 }  
  266.                 L += n * Math.log10(sum);  
  267.             }  
  268.         }  
  269.         return L;  
  270.     }  
  271. }  
  272. public class PLSA {  
  273.     public static void main(String[] args) {  
  274.         ProbabilisticLSA plsa = new ProbabilisticLSA();  
  275.         //the file is not used, the hard coded data is used instead, but file name should be valid,  
  276.         //just replace the name by something valid.  
  277.         plsa.doPLSA("C:\\Users\\APolar\\workspace\\PLSA\\src\\data.txt", 2, 60);  
  278.         System.out.println("end PLSA");  
  279.     }  
  280. }  

4 Estimate parameters in a simple mixture unigram language model by EM

在PLSA的參數估計中,我們使用了EM算法。EM算法經常用來估計包含“缺失資料”或者“隐含變量”模型的參數估計問題。這兩個概念是互相聯系的,當我們的模型中有“隐含變量”時,我們會認為原始資料是“不完全的資料”,因為隐含變量的值無法觀察到;反過來,當我們的資料incomplete時,我們可以通過增加隐含變量來對“缺失資料”模組化。

為了加深對EM算法的了解,下面我們來看如何用EM算法來估計一個簡單混合unigram語言模型的參數。本部分主要參考Zhai老師的EM算法Notes。

4.1 最大似然估計與隐含變量引入

所謂unigram語言模型,就是建構語言模型是抛棄所有上下文資訊,認為一個詞出現的機率與其所在位置無關,具體機率圖模型可以參見LDA及Gibbs Samping一文中的介紹。什麼是混合模型(mixture model)呢?通俗的說混合機率模型就是由最基本的機率分布比如正态分布、多元分布等經過線性組合形成的新的機率模型,比如混合高斯模型就是由K個高斯分布線性組合而得到。混合模型中産生資料的确切“component model”對我們是隐藏的。我們假設混合模型包含兩個multinomial component model,一個是背景詞生成模型

幾種機率語言模型和參數學習方法

,另一個是主題詞生成模型

幾種機率語言模型和參數學習方法

。注意這種模型組成方式在機率語言模型中很常見。為了表示單詞是哪個模型生成的,我們會為每個單詞增加一個布爾類型的控制變量。

文檔的對數似然函數為

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

為第i個文檔中的第j個詞,

幾種機率語言模型和參數學習方法

為表示文檔中背景詞比例的參數,通常根據經驗給定。是以

幾種機率語言模型和參數學習方法

是已知的,我們隻需要估計

幾種機率語言模型和參數學習方法

即可。

同樣的我們首先試圖用最大似然估計來估計參數。也就是去找最大化似然函數的參數值,有

幾種機率語言模型和參數學習方法

這是一個關于

幾種機率語言模型和參數學習方法

的函數,同樣的,

幾種機率語言模型和參數學習方法

包含在了對數和中。是以很難求解極大值,用拉格朗日乘數法,你會發現偏導數等于0得到的方程很難求解。是以我們需要依賴數值算法,而EM算法就是其中常用的一種。

我們為每個單詞引入一個布爾類型的變量z表示該單詞是background word 還是topic word.即

幾種機率語言模型和參數學習方法

這裡我們假設"complete data"不僅包含可以觀察到F中的所有單詞,而且還包括隐含的變量z。那麼根據EM算法,在E步驟我們計算“complete data”的對數似然函數有

幾種機率語言模型和參數學習方法

比較一下

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

,求和運算在對數之外進行,因為此時通過控制變量z的設定,我們明确知道了單詞是由背景詞分布還是topic 詞分布産生的。

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

的關系是怎樣的呢?如果帶估計參數是

幾種機率語言模型和參數學習方法

,原始資料是X,對于每一個原始資料配置設定了一個隐含變量H,則有

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

4.2 似然函數的下界分析

EM算法的基本思想就是初始随機給定待估計參數的值,然後通過E步驟和M步驟兩步疊代去不斷搜尋更好的參數值。更好的參數值應該要滿足使得似然函數更大。我們假設一個潛在的更好參數值是

幾種機率語言模型和參數學習方法

,第n次疊代M步驟得到的參數估計值是

幾種機率語言模型和參數學習方法

,那麼兩個參數值對應的似然函數和"complete data"的似然函數的差滿足

幾種機率語言模型和參數學習方法

我們尋找更好參數值的目标就是要最大化

幾種機率語言模型和參數學習方法

,也等價于最大化

幾種機率語言模型和參數學習方法

。我們來計算隐含變量在給定目前資料X和目前估計的參數值

幾種機率語言模型和參數學習方法

條件下的條件機率分布即

幾種機率語言模型和參數學習方法

,有

幾種機率語言模型和參數學習方法

其中右邊第三項是

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

的相對熵,總為非負值。是以我們有

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

于是我們得到了潛在更好參數值

幾種機率語言模型和參數學習方法

的incomplete data似然函數的下界。這裡我們尤其要注意右邊後兩項為常數,因為不包含

幾種機率語言模型和參數學習方法

。是以incomplete data似然函數的下界就是complete data似然函數的期望,也就是諸多EM算法講義中出現的Q函數,表達式為

幾種機率語言模型和參數學習方法

可以看出這個期望等于complete data似然函數乘以對應隐含變量條件機率再求和。對于我們要求解的問題,Q函數就是

幾種機率語言模型和參數學習方法

這裡多解釋幾句Q函數。單詞相應的變量z為0時,單詞為topic word,從多元分布

幾種機率語言模型和參數學習方法

中産生;當z為1時,單詞為background word,從多元分布

幾種機率語言模型和參數學習方法

産生。同時我們也可以看到如何求Q函數即complete data似然函數的期望,也就是我們要最大化的那個期望(EM算法最大化期望指的就是這個期望),我們要特别關注隐含變量在觀察到資料X和前一輪估計出的參數值

幾種機率語言模型和參數學習方法

條件下取不同值的機率,而隐含變量不同的值對應complete data的不同的似然函數,我們要計算的所謂的期望就是指complete data的似然函數值在不同隐含變量取值情況下的期望值。

4.3 EM算法的一般步驟

通過4.2部分的分析,我們知道,如果我們在下一輪疊代中可以找到一個更好的參數值

幾種機率語言模型和參數學習方法

使得

幾種機率語言模型和參數學習方法

那麼相應的也會有

幾種機率語言模型和參數學習方法

,是以EM算法的一般步驟如下

(1) 随機初始化參數值

幾種機率語言模型和參數學習方法

,也可以根據任何關于最佳參數取值範圍的先驗知識來初始化

幾種機率語言模型和參數學習方法

(2) 不斷兩步疊代尋找更優的參數值

幾種機率語言模型和參數學習方法

     (a) E步驟(求期望) 計算Q函數 

幾種機率語言模型和參數學習方法

     (b)M步驟(最大化)通過最大化Q函數來尋找更優的參數值

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

(3) 當似然函數

幾種機率語言模型和參數學習方法

收斂時算法停止。

這裡需要注意如何盡量保證EM算法可以找到全局最優解而不是局部最優解呢?第一種方法是嘗試許多不同的參數初始值,然後從得到的很多估計出的參數值中選取最優的;第二種方法是通過一個更簡單的模型比如隻有唯一全局最大值的模型來決定複雜模型的初始值。

通過前面的分析可以知道,EM算法的優勢在于complete data的似然函數

幾種機率語言模型和參數學習方法

更容易最大化,因為已經假定了隐含變量的取值,當然要乘以隐含變量取該值的條件機率,是以最終變成了最大化期望值。由于隐含變量變成了已知量,Q函數比原始incomplete data的似然函數更容易求最大值。是以對于“缺失資料”的情況,我們通過引入隐含變量使得complete data的似然函數容易最大化。

在E步驟中,主要的計算難點在于計算隐含變量的條件機率

幾種機率語言模型和參數學習方法

,在PLSA中就是

幾種機率語言模型和參數學習方法

在我們這個簡單混合語言模型的例子中就是

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

我們假設z的取值隻于目前那一個單詞有關,計算很容易,但是在LDA中用這種方法計算隐含變量的條件機率和最大化Q函數就比較複雜,可以參見原始LDA論文的參數推導部分。我們也可以用更簡單的Gibbs Sampling來估計參數,具體可以參見LDA及Gibbs Samping。

繼續我們的問題,下面便是M步驟。使用拉格朗日乘數法來求Q函數的最大值,限制條件是

幾種機率語言模型和參數學習方法

構造拉格朗日輔助函數

幾種機率語言模型和參數學習方法

對自變量

幾種機率語言模型和參數學習方法

求偏導數

幾種機率語言模型和參數學習方法

令偏導數為0解出來唯一的極值點

幾種機率語言模型和參數學習方法

容易知道這裡唯一的極值點就是最值點了。注意這裡Zhai老師變換了一下變量表示,把對文檔裡面詞的周遊轉化成了對詞典裡面的term的周遊,因為z的取值至于對應的那一個單詞有關,與上下文無關。是以E步驟求隐含變量的條件機率公式也相應變成了

幾種機率語言模型和參數學習方法

最後我們就得到了簡單混合Unigram語言模型的EM算法更新公式

即E步驟 求隐含變量條件機率和M步驟 最大化期望估計參數的公式

幾種機率語言模型和參數學習方法

整個計算過程我們可以看到,我們不需要明确求出Q函數的表達式。取而代之的是我們計算隐含變量的條件機率,然後通過最大化Q函數來得到新的參數估計值。

是以EM算法兩步疊代的過程實質是在尋找更好的待估計參數的值使得原始資料即incomplete data似然函數的下界不斷提升,而這個“下界“就是引入隐含變量之後的complete data似然函數的期望,也就是諸多EM算法講義中出現的Q函數,通過最大化Q函數來尋找更優的參數值。同時,上一輪估計出的參數值會在下一輪E步驟中當成已知條件計算隐含變量的條件機率,而這個條件機率又是最大化Q函數求新的參數值是所必需的。

5 Estimate parameters in GMM by EM

經過第3部分和第4部分用EM算法求解PLSA和簡單unigram混合模型參數估計問題的詳細分析,相信大部分讀者已經對EM算法有了一定了解。關于EM算法的材料包括PRML會首先介紹用EM算法去求解混合高斯模型GMM的參數估計問題。下面就讓我們來看看如何用EM算法來求解混合高斯模型GMM。

混合高斯模型GMM由K個高斯模型的線性組合組成,高斯模型就是正态分布模型,其中每個高斯模型我們成為一個”Component“,GMM的機率密度函數就是這K個高斯模型機率密度函數的線性組合即

幾種機率語言模型和參數學習方法

其中

幾種機率語言模型和參數學習方法

就是高斯分布即正态分布的機率密度函數。這是x為向量的情況,對于x為标量的情況就是

幾種機率語言模型和參數學習方法

大部分讀者應該對标量情形的機率分布更熟悉。這裡啰嗦幾句,最近看機器學習的論文和書籍,裡面的随機變量基本都是多元向量,向量的計算比如加減乘除和求導運算都和标量運算有一些差別,尤其是求導運算,向量和矩陣的求導運算會麻煩很多,看pluskid推薦的一本冊子Matrix Cookbook,裡面有很多矩陣求導公式,直接查閱應該會更友善。

下面繼續說GMM。根據上面給出的機率密度函數,如果我們要從 GMM 的分布中Sample一個樣本,實際上可以分為兩步:首先随機地在這 

幾種機率語言模型和參數學習方法

 個 Component 之中選一個,每個 Component 被選中的機率實際上就是它的系數 

幾種機率語言模型和參數學習方法

 ,選中了 Component 之後,再單獨地考慮從這個 Component 的分布中選取一個樣本點就可以了。在PRML上,引入了一個K維二值随機變量z,隻有1維是1,其他維都是0。唯一那個非零的維對應的就是GMM參數樣本時被選中的那個高斯分布,而某一維非零的機率就是

幾種機率語言模型和參數學習方法

,即

幾種機率語言模型和參數學習方法

下面我們開始估計GMM的參數,包括這K個高斯分布的所有均值和方差以及線性組合的系數。我們給每個樣本資料增加一個隐含變量

幾種機率語言模型和參數學習方法

, 就是上面所說的K維向量,表明了

幾種機率語言模型和參數學習方法

是從哪個高斯分布中sample出來的。對應的機率圖模型就是

幾種機率語言模型和參數學習方法

觀察變量的對數似然函數為

幾種機率語言模型和參數學習方法

令對

幾種機率語言模型和參數學習方法

的偏導數等于0我們有

幾種機率語言模型和參數學習方法

注意這裡我們定義了

幾種機率語言模型和參數學習方法

表示後驗機率

幾種機率語言模型和參數學習方法

,也就是第n個樣本是有第k個高斯分布産生的機率。可以解出

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

就是由第K個高斯分布産生的樣本點的總數;用聚類的觀點看,就是聚到cluster k的樣本點總數。然後我們将對數似然函數對

幾種機率語言模型和參數學習方法

求偏導數,令偏導數為0,得到協方差矩陣

幾種機率語言模型和參數學習方法

最後我們求系數

幾種機率語言模型和參數學習方法

。注意到系數的和為1,即

幾種機率語言模型和參數學習方法

這就是限制條件,最大化對數似然函數又成為了條件極值問題。我們仍然用拉格朗日乘數法,構造輔助函數如下

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

求導數,令導數為0有

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

這樣我們就估計出來系數項。

是以用EM算法估計GMM參數的步驟如下

(1) E步驟:估計資料由每個 Component 生成的機率:對于每個資料

幾種機率語言模型和參數學習方法

  來說,它由第 

幾種機率語言模型和參數學習方法

 個 Component 生成的機率為

幾種機率語言模型和參數學習方法

注意裡面 

幾種機率語言模型和參數學習方法

 和 

幾種機率語言模型和參數學習方法

 也是需要我們估計的值,在E步驟我們假定 

幾種機率語言模型和參數學習方法

 和 

幾種機率語言模型和參數學習方法

 均已知,我們使用上一次疊代所得的值(或者初始值)。

(2)M步驟:由最大估計求出高斯分布的所有均值、方差和線性組合的系數,更新待估計的參數值,根據上面的推導,計算公式是

幾種機率語言模型和參數學習方法

其中

幾種機率語言模型和參數學習方法

(3)重複疊代E步驟和M步驟,直到似然函數

幾種機率語言模型和參數學習方法

收斂時算法停止。

更多關于EM算法的深入分析,可以參考PRML第9章内容。

最後我們給出用EM算法估計GMM參數的Matlab實作,出自pluskid的部落格

[plain]   view plain copy

  1. function varargout = gmm(X, K_or_centroids)  
  2. % ============================================================  
  3. % Expectation-Maximization iteration implementation of  
  4. % Gaussian Mixture Model.  
  5. %  
  6. % PX = GMM(X, K_OR_CENTROIDS)  
  7. % [PX MODEL] = GMM(X, K_OR_CENTROIDS)  
  8. %  
  9. %  - X: N-by-D data matrix.  
  10. %  - K_OR_CENTROIDS: either K indicating the number of  
  11. %       components or a K-by-D matrix indicating the  
  12. %       choosing of the initial K centroids.  
  13. %  
  14. %  - PX: N-by-K matrix indicating the probability of each  
  15. %       component generating each point.  
  16. %  - MODEL: a structure containing the parameters for a GMM:  
  17. %       MODEL.Miu: a K-by-D matrix.  
  18. %       MODEL.Sigma: a D-by-D-by-K matrix.  
  19. %       MODEL.Pi: a 1-by-K vector.  
  20. % ============================================================  
  21.     threshold = 1e-15;  
  22.     [N, D] = size(X);  
  23.     if isscalar(K_or_centroids)  
  24.         K = K_or_centroids;  
  25.         % randomly pick centroids  
  26.         rndp = randperm(N);  
  27.         centroids = X(rndp(1:K), :);  
  28.     else  
  29.         K = size(K_or_centroids, 1);  
  30.         centroids = K_or_centroids;  
  31.     end  
  32.     % initial values  
  33.     [pMiu pPi pSigma] = init_params();  
  34.     Lprev = -inf;  
  35.     while true  
  36.         Px = calc_prob();  
  37.         % new value for pGamma  
  38.         pGamma = Px .* repmat(pPi, N, 1);  
  39.         pGamma = pGamma ./ repmat(sum(pGamma, 2), 1, K);  
  40.         % new value for parameters of each Component  
  41.         Nk = sum(pGamma, 1);  
  42.         pMiu = diag(1./Nk) * pGamma' * X;  
  43.         pPi = Nk/N;  
  44.         for kk = 1:K  
  45.             Xshift = X-repmat(pMiu(kk, :), N, 1);  
  46.             pSigma(:, :, kk) = (Xshift' * ...  
  47.                 (diag(pGamma(:, kk)) * Xshift)) / Nk(kk);  
  48.         end  
  49.         % check for convergence  
  50.         L = sum(log(Px*pPi'));  
  51.         if L-Lprev < threshold  
  52.             break;  
  53.         end  
  54.         Lprev = L;  
  55.     end  
  56.     if nargout == 1  
  57.         varargout = {Px};  
  58.     else  
  59.         model = [];  
  60.         model.Miu = pMiu;  
  61.         model.Sigma = pSigma;  
  62.         model.Pi = pPi;  
  63.         varargout = {Px, model};  
  64.     end  
  65.     function [pMiu pPi pSigma] = init_params()  
  66.         pMiu = centroids;  
  67.         pPi = zeros(1, K);  
  68.         pSigma = zeros(D, D, K);  
  69.         % hard assign x to each centroids  
  70.         distmat = repmat(sum(X.*X, 2), 1, K) + ...  
  71.             repmat(sum(pMiu.*pMiu, 2)', N, 1) - ...  
  72.             2*X*pMiu';  
  73.         [dummy labels] = min(distmat, [], 2);  
  74.         for k=1:K  
  75.             Xk = X(labels == k, :);  
  76.             pPi(k) = size(Xk, 1)/N;  
  77.             pSigma(:, :, k) = cov(Xk);  
  78.         end  
  79.     end  
  80.     function Px = calc_prob()  
  81.         Px = zeros(N, K);  
  82.         for k = 1:K  
  83.             Xshift = X-repmat(pMiu(k, :), N, 1);  
  84.             inv_pSigma = inv(pSigma(:, :, k));  
  85.             tmp = sum((Xshift*inv_pSigma) .* Xshift, 2);  
  86.             coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));  
  87.             Px(:, k) = coef * exp(-0.5*tmp);  
  88.         end  
  89.     end  
  90. end  

6 全文總結

本文主要介紹PLSA及EM算法,首先給出LSA(隐性語義分析)的早期方法SVD,然後引入基于機率的PLSA模型,接着我們詳細分析了如何用EM算法估計PLSA、混合unigram 語言模型及混合高斯模型的參數過程,并總結了EM算法的一般形式和運用關鍵點。關于EM算法收斂性的證明可以參考斯坦福機器學習課程CS229 Andrew Ng老師的課程notes和JerryLead的筆記。EM算法在”缺失資料“和包含”隐含變量“的機率模型參數估計問題中非常常用,是機器學習、資料挖掘及NLP研究必須掌握的算法。

 參考文獻及推薦Notes

本文主要參考了Hoffman的PLSA論文、Zhai老師的EM Notes以及PRML第9章内容。

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.

[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.

[3] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[4] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[5] CX Zhai, A note on the expectation-maximization (em) algorithm 2007

[6] Qiaozhu Mei, A Note on EM Algorithm for Probabilistic Latent Semantic Analysis 2008

[7] pluskid, 漫談Clustering, Gaussina Mixture Model

[8] Christopher D. Manning, Prabhakar Raghavan and Hinrich Schütze, Introduction to Information Retrieval, Cambridge University Press. 2008.

[9] Tomas Hoffman, Unsupervised Learning by Probabilistic Latent Semantic Analysis. 2011

****************************************************************************************************************

第二篇 LDA及Gibbs Sampling

[本文PDF版本下載下傳位址 LDA及Gibbs Sampling-yangliuy]

 1 LDA概要      

 LDA是由Blei,Ng, Jordan 2002年發表于JMLR的機率語言模型,應用到文本模組化範疇,就是對文本進行“隐性語義分析”(LSA),目的是要以無指導學習的方法從文本中發現隐含的語義次元-即“Topic”或者“Concept”。隐性語義分析的實質是要利用文本中詞項(term)的共現特征來發現文本的Topic結構,這種方法不需要任何關于文本的背景知識。文本的隐性語義表示可以對“一詞多義”和“一義多詞”的語言現象進行模組化,這使得搜尋引擎系統得到的搜尋結果與使用者的query在語義層次上match,而不是僅僅隻是在詞彙層次上出現交集。

2 機率基礎

2.1 随機生成過程及共轭分布

     要了解LDA首先要了解随機生成過程。用随機生成過程的觀點來看,文本是一系列服從一定機率分布的詞項的樣本集合。最常用的分布就是Multinomial分布,即多項分布,這個分布是二項分布拓展到K維的情況,比如投擲骰子實驗,N次實驗結果服從K=6的多項分布。相應的,二項分布的先驗Beta分布也拓展到K維,稱為Dirichlet分布。在機率語言模型中,通常為Multinomial分布選取的先驗分布是Dirichlet分布,因為它們是共轭分布,可以帶來計算上的友善性。什麼是共轭分布呢?在文本語言模型的參數估計-最大似然估計、MAP及貝葉斯估計一文中我們可以看到,當我們為二項分布的參數p選取的先驗分布是Beta分布時,以p為參數的二項分布用貝葉斯估計得到的後驗機率仍然服從Beta分布,由此我們說二項分布和Beta分布是共轭分布。這就是共轭分布要滿足的性質。在LDA中,每個文檔中詞的Topic分布服從Multinomial分布,其先驗選取共轭先驗即Dirichlet分布;每個Topic下詞的分布服從Multinomial分布,其先驗也同樣選取共轭先驗即Dirichlet分布。

 2.2 Multinomial分布和 Dirichlet分布

    上面從二項分布和Beta分布出發引出了Multinomial分布和Dirichlet分布。這兩個分布在機率語言模型中很常用,讓我們深入了解這兩個分布。Multinomial分布的分布律如下

幾種機率語言模型和參數學習方法

   多項分布來自N次獨立重複實驗,每次實驗結果可能有K種,式子中

幾種機率語言模型和參數學習方法

為實驗結果向量,N為實驗次數,

幾種機率語言模型和參數學習方法

為出現每種實驗結果的機率組成的向量,這個公式給出了出現所有實驗結果的機率計算方法。當K=2時就是二項分布,K=6時就是投擲骰子實驗。很好了解,前面的系數其實是枚舉實驗結果的不同出現順序,即

幾種機率語言模型和參數學習方法

後面表示第K種實驗結果出現了

幾種機率語言模型和參數學習方法

次,是以是機率的相應次幂再求乘積。但是如果我們不考慮文本中詞出現的順序性,這個系數就是1。 本文後面的部分可以看出這一點。顯然有

幾種機率語言模型和參數學習方法

各維之和為1,所有

幾種機率語言模型和參數學習方法

之和為N。

    Dirichlet分布可以看做是“分布之上的分布”,從Dirichlet分布上Draw出來的每個樣本就是多項分布的參數向量

幾種機率語言模型和參數學習方法

。其分布律為

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

為Dirichlet分布的參數,在機率語言模型中通常會根據經驗給定,由于是參數向量

幾種機率語言模型和參數學習方法

服從分布的參數,是以稱為“hyperparamer”。

幾種機率語言模型和參數學習方法

是Dirichlet delta函數,可以看做是Beta函數拓展到K維的情況,但是在有的文獻中也直接寫成

幾種機率語言模型和參數學習方法

。根據Dirichlet分布在

幾種機率語言模型和參數學習方法

上的積分為1(機率的基本性質),我們可以得到一個重要的公式

幾種機率語言模型和參數學習方法

這個公式在後面LDA的參數Inference中經常使用。下圖給出了一個Dirichlet分布的執行個體

幾種機率語言模型和參數學習方法

在許多應用場合,我們使用對稱Dirichlet分布,其參數是兩個标量:維數K和參數向量各維均值

幾種機率語言模型和參數學習方法

. 其分布律如下

幾種機率語言模型和參數學習方法

關于Dirichlet分布,維基百科上有一張很有意思的圖如下

幾種機率語言模型和參數學習方法

這個圖将Dirichlet分布的機率密度函數取對數

幾種機率語言模型和參數學習方法

并且使用對稱Dirichlet分布,取K=3,也就是有兩個獨立參數 

幾種機率語言模型和參數學習方法

 ,分别對應圖中的兩個坐标軸,第三個參數始終滿足

幾種機率語言模型和參數學習方法

且 

幾種機率語言模型和參數學習方法

 ,圖中反映的是

幾種機率語言模型和參數學習方法

從0.3變化到2.0的機率對數值的變化情況。

3 unigram model

我們先介紹比較簡單的unigram model。其機率圖模型圖示如下

幾種機率語言模型和參數學習方法

關于機率圖模型尤其是貝葉斯網絡的介紹可以參見 Stanford機率圖模型(Probabilistic Graphical Model)— 第一講 貝葉斯網絡基礎一文。簡單的說,貝葉斯網絡是一個有向無環圖,圖中的結點是随機變量,圖中的有向邊代表了随機變量的條件依賴關系。unigram model假設文本中的詞服從Multinomial分布,而Multinomial分布的先驗分布為Dirichlet分布。圖中雙線圓圈

幾種機率語言模型和參數學習方法

表示我們在文本中觀察到的第n個詞,

幾種機率語言模型和參數學習方法

表示文本中一共有N個詞。加上方框表示重複,就是說一共有N個這樣的随機變量

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

是隐含未知變量,分别是詞服從的Multinomial分布的參數和該Multinomial分布的先驗Dirichlet分布的參數。一般

幾種機率語言模型和參數學習方法

由經驗事先給定,

幾種機率語言模型和參數學習方法

由觀察到的文本中出現的詞學習得到,表示文本中出現每個詞的機率。

4 LDA

 了解了unigram model之後,我們來看LDA。我們可以假想有一位大作家,比如莫言,他現在要寫m篇文章,一共涉及了K個Topic,每個Topic下的詞分布為一個從參數為

幾種機率語言模型和參數學習方法

的Dirichlet先驗分布中sample出來的Multinomial分布(注意詞典由term構成,每篇文章由word構成,前者不能重複,後者可以重複)。對于每篇文章,他首先會從一個泊松分布中sample一個值作為文章長度,再從一個參數為

幾種機率語言模型和參數學習方法

的Dirichlet先驗分布中sample出一個Multinomial分布作為該文章裡面出現每個Topic下詞的機率;當他想寫某篇文章中的第n個詞的時候,首先從該文章中出現每個Topic的Multinomial分布中sample一個Topic,然後再在這個Topic對應的詞的Multinomial分布中sample一個詞作為他要寫的詞。不斷重複這個随機生成過程,直到他把m篇文章全部寫完。這就是LDA的一個形象通俗的解釋。用數學的語言描述就是如下過程

幾種機率語言模型和參數學習方法

轉化成機率圖模型表示就是

幾種機率語言模型和參數學習方法

圖中K為主題個數,M為文檔總數,

幾種機率語言模型和參數學習方法

是第m個文檔的單詞總數。

幾種機率語言模型和參數學習方法

 是每個Topic下詞的多項分布的Dirichlet先驗參數, 

幾種機率語言模型和參數學習方法

  是每個文檔下Topic的多項分布的Dirichlet先驗參數。

幾種機率語言模型和參數學習方法

是第m個文檔中第n個詞的主題,

幾種機率語言模型和參數學習方法

是m個文檔中的第n個詞。剩下來的兩個隐含變量

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

分别表示第m個文檔下的Topic分布和第k個Topic下詞的分布,前者是k維(k為Topic總數)向量,後者是v維向量(v為詞典中term總數)。

    給定一個文檔集合,

幾種機率語言模型和參數學習方法

是可以觀察到的已知變量,

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

是根據經驗給定的先驗參數,其他的變量

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

都是未知的隐含變量,也是我們需要根據觀察到的變量來學習估計的。根據LDA的圖模型,我們可以寫出所有變量的聯合分布

幾種機率語言模型和參數學習方法

那麼一個詞

幾種機率語言模型和參數學習方法

初始化為一個term t的機率是

幾種機率語言模型和參數學習方法

也就是每個文檔中出現topic k的機率乘以topic k下出現term t的機率,然後枚舉所有topic求和得到。整個文檔集合的似然函數就是

幾種機率語言模型和參數學習方法

5 用Gibbs Sampling學習LDA

5.1   Gibbs Sampling的流程

 從第4部分的分析我們知道,LDA中的變量

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

都是未知的隐含變量,也是我們需要根據觀察到的文檔集合中的詞來學習估計的,那麼如何來學習估計呢?這就是機率圖模型的Inference問題。主要的算法分為exact inference和approximate inference兩類。盡管LDA是最簡單的Topic Model, 但是其用exact inference還是很困難的,一般我們采用approximate inference算法來學習LDA中的隐含變量。比如LDA原始論文Blei02中使用的mean-field variational expectation maximisation 算法和Griffiths02中使用的Gibbs Sampling,其中Gibbs Sampling 更為簡單易懂。

    Gibbs Sampling 是Markov-Chain Monte Carlo算法的一個特例。這個算法的運作方式是每次選取機率向量的一個次元,給定其他次元的變量值Sample目前次元的值。不斷疊代,直到收斂輸出待估計的參數。可以圖示如下

幾種機率語言模型和參數學習方法

   初始時随機給文本中的每個單詞配置設定主題

幾種機率語言模型和參數學習方法

,然後統計每個主題z下出現term t的數量以及每個文檔m下出現主題z中的詞的數量,每一輪計算

幾種機率語言模型和參數學習方法

,即排除目前詞的主題配置設定,根據其他所有詞的主題配置設定估計目前詞配置設定各個主題的機率。當得到目前詞屬于所有主題z的機率分布後,根據這個機率分布為該詞sample一個新的主題

幾種機率語言模型和參數學習方法

。然後用同樣的方法不斷更新下一個詞的主題,直到發現每個文檔下Topic分布

幾種機率語言模型和參數學習方法

和每個Topic下詞的分布

幾種機率語言模型和參數學習方法

收斂,算法停止,輸出待估計的參數

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

,最終每個單詞的主題

幾種機率語言模型和參數學習方法

也同時得出。實際應用中會設定最大疊代次數。每一次計算

幾種機率語言模型和參數學習方法

的公式稱為Gibbs updating rule.下面我們來推導LDA的聯合分布和Gibbs updating rule。

5.2   LDA的聯合分布

由LDA的機率圖模型,我們可以把LDA的聯合分布寫成

幾種機率語言模型和參數學習方法

第一項和第二項因子分别可以寫成

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

可以發現兩個因子的展開形式很相似,第一項因子是給定主題Sample詞的過程,可以拆分成從Dirichlet先驗中SampleTopic Z下詞的分布

幾種機率語言模型和參數學習方法

和從參數為

幾種機率語言模型和參數學習方法

的多元分布中Sample詞這兩個步驟,是以是Dirichlet分布和Multinomial分布的機率密度函數相乘,然後在

幾種機率語言模型和參數學習方法

上積分。注意這裡用到的多元分布沒有考慮詞的順序性,是以沒有前面的系數項。

幾種機率語言模型和參數學習方法

表示term t被觀察到配置設定topic z的次數,

幾種機率語言模型和參數學習方法

表示topic k配置設定給文檔m中的word的次數.此為這裡面還用到了2.2部分中導出的一個公式

幾種機率語言模型和參數學習方法

是以這些積分都可以轉化成Dirichlet delta函數,并不需要算積分。第二個因子是給定文檔,sample目前詞的主題的過程。由此LDA的聯合分布就可以轉化成全部由Dirichlet delta函數組成的表達式

幾種機率語言模型和參數學習方法

這個式子在後面推導Gibbs updating rule時需要使用。

5.3   Gibbs updating rule

得到LDA的聯合分布後,我們就可以推導Gibbs updating rule,即排除目前詞的主題配置設定,根據其他詞的主題配置設定和觀察到的單詞來計算目前詞主題的機率公式

幾種機率語言模型和參數學習方法

裡面用到了伽馬函數的性質

幾種機率語言模型和參數學習方法

同時需要注意到

幾種機率語言模型和參數學習方法

這一項與目前詞的主題配置設定無關,因為無論配置設定那個主題,對所有k求和的結果都是一樣的,差別隻在于拿掉的是哪個主題下的一個詞。是以可以當成常量,最後我們隻需要得到一個成正比的計算式來作為Gibbs updating rule即可。

5.4 Gibbs sampling algorithm

當Gibbs sampling 收斂後,我們需要根據最後文檔集中所有單詞的主題配置設定來計算

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

,作為我們估計出來的機率圖模型中的隐含變量。每個文檔上Topic的後驗分布和每個Topic下的term後驗分布如下

幾種機率語言模型和參數學習方法

可以看出這兩個後驗分布和對應的先驗分布一樣,仍然為Dirichlet分布,這也是共轭分布的性質決定的。

使用Dirichlet分布的期望計算公式

幾種機率語言模型和參數學習方法

我們可以得到兩個Multinomial分布的參數

幾種機率語言模型和參數學習方法

幾種機率語言模型和參數學習方法

的計算公式如下

幾種機率語言模型和參數學習方法

綜上所述,用Gibbs Sampling 學習LDA參數的算法僞代碼如下

幾種機率語言模型和參數學習方法

關于這個算法的代碼實作可以參見

* yangliuy's LDAGibbsSampling https://github.com/yangliuy/LDAGibbsSampling

* Gregor Heinrich's LDA-J

* Yee Whye Teh's Gibbs LDA Matlab codes

* Mark Steyvers and Tom Griffiths's topic modeling matlab toolbox

* GibbsLDA++

6 參考文獻及推薦Notes

本文部分公式及圖檔來自 Parameter estimation for text analysis,感謝Gregor Heinrich詳實細緻的Technical report。 看過的一些關于LDA和Gibbs Sampling 的Notes, 這個是最準确細緻的,内容最為全面系統。下面幾個Notes對Topic Model感興趣的朋友也推薦看一看。

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.

[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.

[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details Technical report, 2005.

[4] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[5] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[6] Wikipedia, Dirichlet distribution , http://en.wikipedia.org/wiki/Dirichlet_distribution

********************************************************************************************************************

第三章:LDA Gibbs Sampling的JAVA 實作

在本系列博文的前兩篇,我們系統介紹了PLSA, LDA以及它們的參數Inference 方法,重點分析了模型表示和公式推導部分。曾有位學者說,“做研究要頂天立地”,意思是說做研究空有模型和理論還不夠,我們還得有紮實的程式code和真實資料的實驗結果來作為支撐。本文就重點分析 LDA Gibbs Sampling的JAVA 實作,并給出apply到newsgroup18828新聞文檔集上得出的Topic模組化結果。

本項目Github位址 https://github.com/yangliuy/LDAGibbsSampling

1、文檔集預處理

要用LDA對文本進行topic模組化,首先要對文本進行預處理,包括token,去停用詞,stem,去noise詞,去掉低頻詞等等。當語料庫比較大時,我們也可以不進行stem。然後将文本轉換成term的index表示形式,因為後面實作LDA的過程中經常需要在term和index之間進行映射。Documents類的實作如下,裡面定義了Document内部類,用于描述文本集合中的文檔。

[java]   view plain copy

  1. package liuyang.nlp.lda.main;  
  2. import java.io.File;  
  3. import java.util.ArrayList;  
  4. import java.util.HashMap;  
  5. import java.util.Map;  
  6. import java.util.regex.Matcher;  
  7. import java.util.regex.Pattern;  
  8. import liuyang.nlp.lda.com.FileUtil;  
  9. import liuyang.nlp.lda.com.Stopwords;  
  10. public class Documents {  
  11.     ArrayList<Document> docs;   
  12.     Map<String, Integer> termToIndexMap;  
  13.     ArrayList<String> indexToTermMap;  
  14.     Map<String,Integer> termCountMap;  
  15.     public Documents(){  
  16.         docs = new ArrayList<Document>();  
  17.         termToIndexMap = new HashMap<String, Integer>();  
  18.         indexToTermMap = new ArrayList<String>();  
  19.         termCountMap = new HashMap<String, Integer>();  
  20.     }  
  21.     public void readDocs(String docsPath){  
  22.         for(File docFile : new File(docsPath).listFiles()){  
  23.             Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);  
  24.             docs.add(doc);  
  25.         }  
  26.     }  
  27.     public static class Document {    
  28.         private String docName;  
  29.         int[] docWords;  
  30.         public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap){  
  31.             this.docName = docName;  
  32.             //Read file and initialize word index array  
  33.             ArrayList<String> docLines = new ArrayList<String>();  
  34.             ArrayList<String> words = new ArrayList<String>();  
  35.             FileUtil.readLines(docName, docLines);  
  36.             for(String line : docLines){  
  37.                 FileUtil.tokenizeAndLowerCase(line, words);  
  38.             }  
  39.             //Remove stop words and noise words  
  40.             for(int i = 0; i < words.size(); i++){  
  41.                 if(Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))){  
  42.                     words.remove(i);  
  43.                     i--;  
  44.                 }  
  45.             }  
  46.             //Transfer word to index  
  47.             this.docWords = new int[words.size()];  
  48.             for(int i = 0; i < words.size(); i++){  
  49.                 String word = words.get(i);  
  50.                 if(!termToIndexMap.containsKey(word)){  
  51.                     int newIndex = termToIndexMap.size();  
  52.                     termToIndexMap.put(word, newIndex);  
  53.                     indexToTermMap.add(word);  
  54.                     termCountMap.put(word, new Integer(1));  
  55.                     docWords[i] = newIndex;  
  56.                 } else {  
  57.                     docWords[i] = termToIndexMap.get(word);  
  58.                     termCountMap.put(word, termCountMap.get(word) + 1);  
  59.                 }  
  60.             }  
  61.             words.clear();  
  62.         }  
  63.         public boolean isNoiseWord(String string) {  
  64.             // TODO Auto-generated method stub  
  65.             string = string.toLowerCase().trim();  
  66.             Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");  
  67.             Matcher m = MY_PATTERN.matcher(string);  
  68.             // filter @xxx and URL  
  69.             if(string.matches(".*www\\..*") || string.matches(".*\\.com.*") ||   
  70.                     string.matches(".*http:.*") )  
  71.                 return true;  
  72.             if (!m.matches()) {  
  73.                 return true;  
  74.             } else  
  75.                 return false;  
  76.         }  
  77.     }  
  78. }  

2 LDA Gibbs Sampling

文本預處理完畢後我們就可以實作LDA Gibbs Sampling。 首先我們要定義需要的參數,我的實作中在程式中給出了參數預設值,同時也支援配置檔案覆寫,程式預設優先選用配置檔案的參數設定。整個算法流程包括模型初始化,疊代Inference,不斷更新主題和待估計參數,最後輸出收斂時的參數估計結果。

包含主函數的配置參數解析類如下:

[java]   view plain copy

  1. package liuyang.nlp.lda.main;  
  2. import java.io.File;  
  3. import java.io.IOException;  
  4. import java.util.ArrayList;  
  5. import liuyang.nlp.lda.com.FileUtil;  
  6. import liuyang.nlp.lda.conf.ConstantConfig;  
  7. import liuyang.nlp.lda.conf.PathConfig;  
  8. public class LdaGibbsSampling {  
  9.     public static class modelparameters {  
  10.         float alpha = 0.5f; //usual value is 50 / K  
  11.         float beta = 0.1f;//usual value is 0.1  
  12.         int topicNum = 100;  
  13.         int iteration = 100;  
  14.         int saveStep = 10;  
  15.         int beginSaveIters = 50;  
  16.     }  
  17.     private static void getParametersFromFile(modelparameters ldaparameters,  
  18.             String parameterFile) {  
  19.         // TODO Auto-generated method stub  
  20.         ArrayList<String> paramLines = new ArrayList<String>();  
  21.         FileUtil.readLines(parameterFile, paramLines);  
  22.         for(String line : paramLines){  
  23.             String[] lineParts = line.split("\t");  
  24.             switch(parameters.valueOf(lineParts[0])){  
  25.             case alpha:  
  26.                 ldaparameters.alpha = Float.valueOf(lineParts[1]);  
  27.                 break;  
  28.             case beta:  
  29.                 ldaparameters.beta = Float.valueOf(lineParts[1]);  
  30.                 break;  
  31.             case topicNum:  
  32.                 ldaparameters.topicNum = Integer.valueOf(lineParts[1]);  
  33.                 break;  
  34.             case iteration:  
  35.                 ldaparameters.iteration = Integer.valueOf(lineParts[1]);  
  36.                 break;  
  37.             case saveStep:  
  38.                 ldaparameters.saveStep = Integer.valueOf(lineParts[1]);  
  39.                 break;  
  40.             case beginSaveIters:  
  41.                 ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);  
  42.                 break;  
  43.             }  
  44.         }  
  45.     }  
  46.     public enum parameters{  
  47.         alpha, beta, topicNum, iteration, saveStep, beginSaveIters;  
  48.     }  
  49.     public static void main(String[] args) throws IOException {  
  50.         // TODO Auto-generated method stub  
  51.         String originalDocsPath = PathConfig.ldaDocsPath;  
  52.         String resultPath = PathConfig.LdaResultsPath;  
  53.         String parameterFile= ConstantConfig.LDAPARAMETERFILE;  
  54.         modelparameters ldaparameters = new modelparameters();  
  55.         getParametersFromFile(ldaparameters, parameterFile);  
  56.         Documents docSet = new Documents();  
  57.         docSet.readDocs(originalDocsPath);  
  58.         System.out.println("wordMap size " + docSet.termToIndexMap.size());  
  59.         FileUtil.mkdir(new File(resultPath));  
  60.         LdaModel model = new LdaModel(ldaparameters);  
  61.         System.out.println("1 Initialize the model ...");  
  62.         model.initializeModel(docSet);  
  63.         System.out.println("2 Learning and Saving the model ...");  
  64.         model.inferenceModel(docSet);  
  65.         System.out.println("3 Output the final model ...");  
  66.         model.saveIteratedModel(ldaparameters.iteration, docSet);  
  67.         System.out.println("Done!");  
  68.     }  
  69. }  

LDA 模型實作類如下

[java]   view plain copy

  1. package liuyang.nlp.lda.main;  
  2. import java.io.BufferedWriter;  
  3. import java.io.FileWriter;  
  4. import java.io.IOException;  
  5. import java.util.ArrayList;  
  6. import java.util.Collections;  
  7. import java.util.Comparator;  
  8. import java.util.List;  
  9. import liuyang.nlp.lda.com.FileUtil;  
  10. import liuyang.nlp.lda.conf.PathConfig;  
  11. public class LdaModel {  
  12.     int [][] doc;//word index array  
  13.     int V, K, M;//vocabulary size, topic number, document number  
  14.     int [][] z;//topic label array  
  15.     float alpha; //doc-topic dirichlet prior parameter   
  16.     float beta; //topic-word dirichlet prior parameter  
  17.     int [][] nmk;//given document m, count times of topic k. M*K  
  18.     int [][] nkt;//given topic k, count times of term t. K*V  
  19.     int [] nmkSum;//Sum for each row in nmk  
  20.     int [] nktSum;//Sum for each row in nkt  
  21.     double [][] phi;//Parameters for topic-word distribution K*V  
  22.     double [][] theta;//Parameters for doc-topic distribution M*K  
  23.     int iterations;//Times of iterations  
  24.     int saveStep;//The number of iterations between two saving  
  25.     int beginSaveIters;//Begin save model at this iteration  
  26.     public LdaModel(LdaGibbsSampling.modelparameters modelparam) {  
  27.         // TODO Auto-generated constructor stub  
  28.         alpha = modelparam.alpha;  
  29.         beta = modelparam.beta;  
  30.         iterations = modelparam.iteration;  
  31.         K = modelparam.topicNum;  
  32.         saveStep = modelparam.saveStep;  
  33.         beginSaveIters = modelparam.beginSaveIters;  
  34.     }  
  35.     public void initializeModel(Documents docSet) {  
  36.         // TODO Auto-generated method stub  
  37.         M = docSet.docs.size();  
  38.         V = docSet.termToIndexMap.size();  
  39.         nmk = new int [M][K];  
  40.         nkt = new int[K][V];  
  41.         nmkSum = new int[M];  
  42.         nktSum = new int[K];  
  43.         phi = new double[K][V];  
  44.         theta = new double[M][K];  
  45.         //initialize documents index array  
  46.         doc = new int[M][];  
  47.         for(int m = 0; m < M; m++){  
  48.             //Notice the limit of memory  
  49.             int N = docSet.docs.get(m).docWords.length;  
  50.             doc[m] = new int[N];  
  51.             for(int n = 0; n < N; n++){  
  52.                 doc[m][n] = docSet.docs.get(m).docWords[n];  
  53.             }  
  54.         }  
  55.         //initialize topic lable z for each word  
  56.         z = new int[M][];  
  57.         for(int m = 0; m < M; m++){  
  58.             int N = docSet.docs.get(m).docWords.length;  
  59.             z[m] = new int[N];  
  60.             for(int n = 0; n < N; n++){  
  61.                 int initTopic = (int)(Math.random() * K);// From 0 to K - 1  
  62.                 z[m][n] = initTopic;  
  63.                 //number of words in doc m assigned to topic initTopic add 1  
  64.                 nmk[m][initTopic]++;  
  65.                 //number of terms doc[m][n] assigned to topic initTopic add 1  
  66.                 nkt[initTopic][doc[m][n]]++;  
  67.                 // total number of words assigned to topic initTopic add 1  
  68.                 nktSum[initTopic]++;  
  69.             }  
  70.              // total number of words in document m is N  
  71.             nmkSum[m] = N;  
  72.         }  
  73.     }  
  74.     public void inferenceModel(Documents docSet) throws IOException {  
  75.         // TODO Auto-generated method stub  
  76.         if(iterations < saveStep + beginSaveIters){  
  77.             System.err.println("Error: the number of iterations should be larger than " + (saveStep + beginSaveIters));  
  78.             System.exit(0);  
  79.         }  
  80.         for(int i = 0; i < iterations; i++){  
  81.             System.out.println("Iteration " + i);  
  82.             if((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)){  
  83.                 //Saving the model  
  84.                 System.out.println("Saving model at iteration " + i +" ... ");  
  85.                 //Firstly update parameters  
  86.                 updateEstimatedParameters();  
  87.                 //Secondly print model variables  
  88.                 saveIteratedModel(i, docSet);  
  89.             }  
  90.             //Use Gibbs Sampling to update z[][]  
  91.             for(int m = 0; m < M; m++){  
  92.                 int N = docSet.docs.get(m).docWords.length;  
  93.                 for(int n = 0; n < N; n++){  
  94.                     // Sample from p(z_i|z_-i, w)  
  95.                     int newTopic = sampleTopicZ(m, n);  
  96.                     z[m][n] = newTopic;  
  97.                 }  
  98.             }  
  99.         }  
  100.     }  
  101.     private void updateEstimatedParameters() {  
  102.         // TODO Auto-generated method stub  
  103.         for(int k = 0; k < K; k++){  
  104.             for(int t = 0; t < V; t++){  
  105.                 phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);  
  106.             }  
  107.         }  
  108.         for(int m = 0; m < M; m++){  
  109.             for(int k = 0; k < K; k++){  
  110.                 theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);  
  111.             }  
  112.         }  
  113.     }  
  114.     private int sampleTopicZ(int m, int n) {  
  115.         // TODO Auto-generated method stub  
  116.         // Sample from p(z_i|z_-i, w) using Gibbs upde rule  
  117.         //Remove topic label for w_{m,n}  
  118.         int oldTopic = z[m][n];  
  119.         nmk[m][oldTopic]--;  
  120.         nkt[oldTopic][doc[m][n]]--;  
  121.         nmkSum[m]--;  
  122.         nktSum[oldTopic]--;  
  123.         //Compute p(z_i = k|z_-i, w)  
  124.         double [] p = new double[K];  
  125.         for(int k = 0; k < K; k++){  
  126.             p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);  
  127.         }  
  128.         //Sample a new topic label for w_{m, n} like roulette  
  129.         //Compute cumulated probability for p  
  130.         for(int k = 1; k < K; k++){  
  131.             p[k] += p[k - 1];  
  132.         }  
  133.         double u = Math.random() * p[K - 1]; //p[] is unnormalised  
  134.         int newTopic;  
  135.         for(newTopic = 0; newTopic < K; newTopic++){  
  136.             if(u < p[newTopic]){  
  137.                 break;  
  138.             }  
  139.         }  
  140.         //Add new topic label for w_{m, n}  
  141.         nmk[m][newTopic]++;  
  142.         nkt[newTopic][doc[m][n]]++;  
  143.         nmkSum[m]++;  
  144.         nktSum[newTopic]++;  
  145.         return newTopic;  
  146.     }  
  147.     public void saveIteratedModel(int iters, Documents docSet) throws IOException {  
  148.         // TODO Auto-generated method stub  
  149.         //lda.params lda.phi lda.theta lda.tassign lda.twords  
  150.         //lda.params  
  151.         String resPath = PathConfig.LdaResultsPath;  
  152.         String modelName = "lda_" + iters;  
  153.         ArrayList<String> lines = new ArrayList<String>();  
  154.         lines.add("alpha = " + alpha);  
  155.         lines.add("beta = " + beta);  
  156.         lines.add("topicNum = " + K);  
  157.         lines.add("docNum = " + M);  
  158.         lines.add("termNum = " + V);  
  159.         lines.add("iterations = " + iterations);  
  160.         lines.add("saveStep = " + saveStep);  
  161.         lines.add("beginSaveIters = " + beginSaveIters);  
  162.         FileUtil.writeLines(resPath + modelName + ".params", lines);  
  163.         //lda.phi K*V  
  164.         BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));         
  165.         for (int i = 0; i < K; i++){  
  166.             for (int j = 0; j < V; j++){  
  167.                 writer.write(phi[i][j] + "\t");  
  168.             }  
  169.             writer.write("\n");  
  170.         }  
  171.         writer.close();  
  172.         //lda.theta M*K  
  173.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));  
  174.         for(int i = 0; i < M; i++){  
  175.             for(int j = 0; j < K; j++){  
  176.                 writer.write(theta[i][j] + "\t");  
  177.             }  
  178.             writer.write("\n");  
  179.         }  
  180.         writer.close();  
  181.         //lda.tassign  
  182.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));  
  183.         for(int m = 0; m < M; m++){  
  184.             for(int n = 0; n < doc[m].length; n++){  
  185.                 writer.write(doc[m][n] + ":" + z[m][n] + "\t");  
  186.             }  
  187.             writer.write("\n");  
  188.         }  
  189.         writer.close();  
  190.         //lda.twords phi[][] K*V  
  191.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));  
  192.         int topNum = 20; //Find the top 20 topic words in each topic  
  193.         for(int i = 0; i < K; i++){  
  194.             List<Integer> tWordsIndexArray = new ArrayList<Integer>();   
  195.             for(int j = 0; j < V; j++){  
  196.                 tWordsIndexArray.add(new Integer(j));  
  197.             }  
  198.             Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));  
  199.             writer.write("topic " + i + "\t:\t");  
  200.             for(int t = 0; t < topNum; t++){  
  201.                 writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");  
  202.             }  
  203.             writer.write("\n");  
  204.         }  
  205.         writer.close();  
  206.     }  
  207.     public class TwordsComparable implements Comparator<Integer> {  
  208.         public double [] sortProb; // Store probability of each word in topic k  
  209.         public TwordsComparable (double[] sortProb){  
  210.             this.sortProb = sortProb;  
  211.         }  
  212.         @Override  
  213.         public int compare(Integer o1, Integer o2) {  
  214.             // TODO Auto-generated method stub  
  215.             //Sort topic word index according to the probability of each word in topic k  
  216.             if(sortProb[o1] > sortProb[o2]) return -1;  
  217.             else if(sortProb[o1] < sortProb[o2]) return 1;  
  218.             else return 0;  
  219.         }  
  220.     }  
  221. }  

程式的實作細節可以參考我在程式中給出的注釋,如果了解LDA Gibbs Sampling的算法流程,上面的代碼很好了解。其實排除輸入輸出和參數解析的代碼,标準LDA 的Gibbs sampling隻需要不到200行程式就可以搞定。當然,裡面有很多可以考慮優化和變形的地方。

還有com和conf目錄下的源檔案分别放置常用函數和配置類,完整的JAVA工程見Github https://github.com/yangliuy/LDAGibbsSampling

3 用LDA Gibbs Sampling對Newsgroup 18828文檔集進行主題分析

下面我們給出将上面的LDA Gibbs Sampling的實作Apply到Newsgroup 18828文檔集進行主題分析的結果。 我實驗時用到的資料已經上傳到Github中,感興趣的朋友可以直接從Github中下載下傳工程運作。 我在Newsgroup 18828文檔集随機選擇了9個目錄,每個目錄下選擇一個文檔,将它們放置在data\LdaOriginalDocs目錄下,我設定的模型參數如下

[plain]   view plain copy

  1. alpha   0.5  
  2. beta    0.1  
  3. topicNum    10  
  4. iteration   100  
  5. saveStep    10  
  6. beginSaveIters  80  

即設定alpha和beta的值為0.5和0.1, Topic數目為10,疊代100次,從第80次開始儲存模型結果,每10次儲存一次。

經過100次Gibbs Sampling疊代後,程式輸出10個Topic下top的topic words以及對應的機率值如下

幾種機率語言模型和參數學習方法
幾種機率語言模型和參數學習方法

我們可以看到雖然是unsupervised learning, LDA分析出來的Topic words還是非常make sense的。比如第5個topic是宗教類的,第6個topic是天文類的,第7個topic是計算機類的。程式的輸出還包括模型參數.param檔案,topic-word分布phi向量.phi檔案,doc-topic分布theta向量.theta檔案以及每個文檔中每個單詞配置設定到的主題label的.tassign檔案。感興趣的朋友可以從Github https://github.com/yangliuy/LDAGibbsSampling 下載下傳完整工程自己換用其他資料集進行主題分析實驗。 本程式是初步實作版本,如果大家發現任何問題或者bug歡迎交流,我第一時間在Github修複bug更新版本。

4 參考文獻

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.

[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.

[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details Technical report, 2005.

[4] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[5] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[6] Jgibblda, http://jgibblda.sourceforge.net/

[7]David M. Blei, Andrew Y. Ng, and Michael I. Jordan. 2003. Latent dirichlet allocation. J. Mach. Learn. Res. 3 (March 2003), 993-1022.