天天看點

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

原文連結

本文發表于資訊檢索領域頂級會議 SIGIR 2017

代碼連結

摘要

在當今網際網路工業界中,有許多預測任務需要用到大量的類别特征。要想将這些類别特征送入到模型中,就必須得将其onehot。但這樣一來,就會産生大量的稀疏特征,要想從這些稀疏特征中充分學習到有用的資訊,必須要考慮特征之間的互相作用。

FM算法是一種常用的解決方案,因為它充分考慮了二階特征之間的互相作用。然而FM有一個缺點,就是它僅僅以線性的方式組合了特征,并不能考慮到特征之間的非線性關系。

本文提出了一個稱為Neural Factorization Machine (NFM)的模型,來解決上述問題。NFM充分結合了FM提取的二階線性特征與神經網絡提取的高階非線性特征。總得來說,FM可以被看作一個沒有隐含層的NFM,故NFM肯定比FM更具表現力。實驗證明,NFM效果不錯。

模型

1.FM

假設我們有一個特征向量 x∈Rn x ∈ R n ,FM算法是通過對每一對特征之間做相乘來提取二階線性特征,公式如下:

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

其中 w0 w 0 為全局的 bias b i a s , wi w i 則是控制每一個特征 xi x i 對預測結果影響的權重。 wij w i j 為二次項的權重。

但這樣會造成一個問題,由上式可知,二次項的個數為 n(n−1)2 n ( n − 1 ) 2 個,并且 xi x i xj x j 都是極度稀疏的向量,要有大量能滿足 xi x i xj x j 同時不為零的樣本才能夠對上式進行訓練(因為一旦有一個為0則相乘為0),這對資料的要求過于苛刻。又由于所有的 wij w i j 可以組成一個對稱矩陣 W W ,而我們可以将矩陣 WW 分解為 W=VTV W = V T V ,故我們可以通過下式來求一個近似解:

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》
論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》
論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

2.NFM

假設我們有特征向量輸入 x∈Rn x ∈ R n 其中 xi=0 x i = 0 代表該樣本沒有第 i i 個特征。NFM的目标函數可以由下式表示:

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

其中第1項與第2項是與FM相似的線性回歸部分,第3項是NFM的核心部分,它由一個如下圖所示的網絡結構組成:

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

我們來逐層解釋上圖:

Embedding Layer

該層是一個全連接配接層,将稀疏的向量給壓縮表示。

假設我們有 vi∈Rkvi∈Rk 為第 i i 個特征的embedding向量,那麼在經過該層之後,我們得到的輸出為 {x1{x1 v1 v 1 ,...,xn , . . . , x n vn v n } } ,注意,該層本質上是一個全連接配接層,不是簡單的embedding lookup.

Bi-Interaction Layer

上層得到的輸出是一個特征向量的embedding的集合,本層本質上是做一個pooling的操作,讓這個embedding向量集合變為一個向量,公式如下:

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

其中 ⨀ ⨀ 代表兩個向量對應的元素相乘。顯然,該層的輸出向量為 k k 維,本層采用的pooling方式與傳統的max pool和average pool一樣都是線性複雜度的,上式可以變換為:

論文筆記《Neural Factorization Machines for Sparse Predictive Analytics》

上式中用 v2v2 來表示 v⨀v v ⨀ v ,其實本層本質上就是一個fm算法。

hidden layer

就是普通的全連接配接層,沒有什麼特别的。

Prediction Layer

将hidden layer的輸出過一個n*1的全連接配接層,得到輸出

繼續閱讀