
[Paper Notes] FM: Factorization Machines

This post records my derivation and understanding of the Factorization Machine (FM) algorithm.

Contents

  • Paper link
  • Derivation of Second-Order FM
  • Gradients (Backpropagation) of Second-Order FM
  • Higher-Order FM
Paper link

https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf

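Derivation of Second-Order FM

When making a prediction, FM takes the interactions between different features into account. Taking second-order (pairwise) interactions as an example: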

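$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i x_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}w_{ij}x_i x_j \tag{1}$$

Here $w_0$, $w_i$ and the pairwise weights $w_{ij}$ (the entries of a matrix $W$) are the parameters the model learns. In real applications the feature vector $x$ is very high-dimensional and sparse, because categorical fields are one-hot encoded, so for most pairs $(i, j)$ the product $x_i x_j$ is zero in every training example. Learning each $w_{ij}$ independently therefore leaves most of them with little or no data and overfits easily.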
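To make the sparsity argument concrete, here is a small sketch with a hypothetical dataset of (user, item) pairs, one-hot encoded into $n = 6$ columns. For a free weight $w_{ij}$ we have $\partial \hat y / \partial w_{ij} = x_i x_j$, so only pairs that co-occur in at least one row ever receive an update. The data and variable names are illustrative assumptions, not from the paper.

```python
import itertools
import numpy as np

# Hypothetical toy data: each row is (user_id, item_id); both are categorical.
rows = [(0, 0), (0, 1), (1, 1), (2, 2)]
n_users, n_items = 3, 3
n = n_users + n_items  # total one-hot dimension

# One-hot encode: user u -> column u, item m -> column n_users + m.
X = np.zeros((len(rows), n))
for r, (u, m) in enumerate(rows):
    X[r, u] = 1.0
    X[r, n_users + m] = 1.0

# For a free pairwise weight w_ij the gradient of y_hat is x_i * x_j,
# so w_ij only gets updates from rows where both features are nonzero.
observed = {
    (i, j)
    for i, j in itertools.combinations(range(n), 2)
    if np.any(X[:, i] * X[:, j] != 0)
}
total_pairs = n * (n - 1) // 2
print(f"{len(observed)} of {total_pairs} pairs ever co-occur")  # 4 of 15 pairs ever co-occur
```

For example, the weight for (user 1, item 0), columns $(1, 3)$, would never be updated, even though predicting that pair is exactly what a recommender needs.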

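Note, however, that only the entries $w_{ij}$ with $i < j$ are used, so $W$ can be taken to be a real symmetric matrix. By the spectral theorem for real symmetric matrices:

Every real symmetric matrix $A$ can be decomposed as $A = Q\Lambda Q^T$, where $\Lambda$ is diagonal and $Q$ is orthogonal.

If $W$ is in addition positive semidefinite (all eigenvalues $\ge 0$), then $W = (Q\Lambda^{1/2})(Q\Lambda^{1/2})^T$, i.e. $W$ can be written as $W = VV^T$. Rendle's paper notes that for any positive definite $W$ such a $V \in \mathbb{R}^{n \times k}$ exists when $k$ is large enough; FM deliberately uses a small $k$, which restricts the interactions to a low-rank form whose factors are shared across all pairs. With $V \in \mathbb{R}^{n \times k}$, eq. (1) becomes:

$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i x_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}\langle v_i, v_j \rangle x_i x_j \tag{2}$$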

Here $v_i$ and $v_j$ are the $i$-th and $j$-th rows of $V$, each a vector of length $k$, and $\langle v_i, v_j \rangle$ is their dot product:

$$\langle v_i, v_j \rangle = \sum_{f=1}^{k}v_{i,f}\, v_{j,f}$$
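The decomposition argument can be checked numerically. The sketch below is a minimal NumPy example on a random matrix (not from the paper): it builds a positive semidefinite $W$, forms $V = Q\Lambda^{1/2}$ with `numpy.linalg.eigh`, and then keeps only the $k$ largest eigenvalues to obtain an $n \times k$ factor.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# Build an arbitrary positive semidefinite W (any A @ A.T is PSD).
A = rng.normal(size=(n, n))
W = A @ A.T

# Symmetric eigendecomposition W = Q diag(lam) Q^T (eigh returns lam in ascending order).
lam, Q = np.linalg.eigh(W)
lam = np.clip(lam, 0.0, None)  # guard against tiny negative round-off

# V = Q diag(sqrt(lam)) scales each eigenvector column, giving W = V V^T (here k = n).
V = Q * np.sqrt(lam)
print(np.allclose(V @ V.T, W))  # True

# Keeping the k largest eigenvalues gives a rank-k factor of shape (n, k).
k = 2
V_k = Q[:, -k:] * np.sqrt(lam[-k:])
print(V_k.shape)  # (5, 2)
```

FM does not compute $V$ from a given $W$; it learns $V$ directly. The truncation here only illustrates what a rank-$k$ interaction matrix $V_k V_k^T$ looks like.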

Therefore:

$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i x_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}\sum_{f=1}^{k}v_{i,f}\, v_{j,f}\, x_i x_j \tag{3}$$

Evaluating eq. (3) directly costs $O(kn^2)$, but by rearranging the computation the cost can be reduced to $O(kn)$.
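A direct implementation of eq. (3) makes the $O(kn^2)$ cost visible: two nested loops over feature pairs, each doing a length-$k$ dot product. This is a sketch; `fm_predict_naive` and its argument layout (a dense NumPy vector `x` and a factor matrix `V` whose $i$-th row is $v_i$) are assumptions for illustration.

```python
import numpy as np

def fm_predict_naive(x, w0, w, V):
    """Second-order FM prediction via eq. (3): O(k n^2).

    x: (n,) features, w0: scalar bias, w: (n,) linear weights,
    V: (n, k) factor matrix whose i-th row is v_i.
    """
    n, k = V.shape
    y = w0 + float(np.dot(w, x))
    for i in range(n):
        for j in range(i + 1, n):
            # <v_i, v_j> = sum_f v_{i,f} * v_{j,f}, an O(k) dot product
            y += float(np.dot(V[i], V[j])) * x[i] * x[j]
    return y

# Tiny hand-checkable example: linear part 0.5 + 4.0, pairwise part 0 + 3 + 6.
x = np.array([1.0, 2.0, 3.0])
V = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(fm_predict_naive(x, 0.5, np.array([1.0, 0.0, 1.0]), V))  # 13.5
```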

Let $M$ be the pairwise term of eq. (3):

$$M=\sum_{i=1}^{n}\sum_{j=i+1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j$$

and let $N$ be the same sum taken over all pairs $(i, j)$:

$$N=\sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j$$

The double sum over all $(i, j)$ splits into the terms with $j < i$, $j = i$ and $j > i$, and because the summand is symmetric in $i$ and $j$, the terms with $j < i$ add up to the same total as those with $j > i$. Hence:

$$\begin{aligned}
N &= \sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j \\
&= \sum_{i=1}^{n}\sum_{f=1}^{k}\Big(\sum_{j=1}^{i-1} v_{i,f}v_{j,f}x_ix_j + v_{i,f}v_{i,f}x_ix_i + \sum_{j=i+1}^{n} v_{i,f}v_{j,f}x_ix_j\Big) \\
&= 2\sum_{i=1}^{n}\sum_{f=1}^{k}\sum_{j=i+1}^{n}v_{i,f}v_{j,f}x_ix_j + \sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}v_{i,f}x_ix_i \\
&= 2M + \sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}^2x_i^2
\end{aligned} \tag{4}$$

Therefore, keeping the sum over $f$ outside and factoring the double sum over $(i, j)$ into a product for each $f$:

$$\begin{aligned}
M &= \frac{1}{2}\Big(N - \sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}^2x_i^2\Big) \\
&= \frac{1}{2}\sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j - \frac{1}{2}\sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}^2x_i^2 \\
&= \frac{1}{2}\sum_{f=1}^{k}\Big(\sum_{i=1}^{n}v_{i,f}x_i\Big)\Big(\sum_{j=1}^{n}v_{j,f}x_j\Big) - \frac{1}{2}\sum_{f=1}^{k}\sum_{i=1}^{n} v_{i,f}^2x_i^2 \\
&= \frac{1}{2}\sum_{f=1}^{k}\Bigg[\Big(\sum_{i=1}^{n}v_{i,f}x_i\Big)^2 - \sum_{i=1}^{n}v_{i,f}^2x_i^2\Bigg]
\end{aligned} \tag{5}$$
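Identities (4) and (5) are easy to sanity-check on random numbers. The sketch below (NumPy, random data, purely illustrative) builds the summand $\langle v_i, v_j\rangle x_i x_j$ as an $n \times n$ matrix, compares its full sum, upper-triangle sum and diagonal, and then checks that the per-factor form of $M$ matches the explicit pairwise sum. The sum over $f$ stays outside the square.

```python
import numpy as np

rng = np.random.default_rng(1)
n, k = 6, 3
x = rng.normal(size=n)
V = rng.normal(size=(n, k))

# P[i, j] = <v_i, v_j> x_i x_j, the summand shared by M and N
P = (V @ V.T) * np.outer(x, x)

N = P.sum()                            # all pairs (i, j)
M = np.triu(P, k=1).sum()              # only j > i
diag = np.sum((V * x[:, None]) ** 2)   # j == i terms: sum_i sum_f v_{i,f}^2 x_i^2

# Eq. (4): N = 2M + diagonal terms
print(np.isclose(N, 2 * M + diag))  # True

# Eq. (5): M = 1/2 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ]
xv = V * x[:, None]                    # xv[i, f] = v_{i,f} * x_i
M_fast = 0.5 * np.sum(xv.sum(axis=0) ** 2 - (xv ** 2).sum(axis=0))
print(np.isclose(M, M_fast))  # True
```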

So eq. (3) can be rewritten as:

$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i x_i+\frac{1}{2}\sum_{f=1}^{k}\Bigg[\Big(\sum_{i=1}^{n}v_{i,f}x_i\Big)^2-\sum_{i=1}^{n}v_{i,f}^2x_i^2\Bigg] \tag{6}$$

Evaluating this expression takes $O(kn)$ time. Since $k$ is a fixed hyperparameter (typically $k \ll n$), the cost is linear in $n$.
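Here is a minimal sketch of the linear-time prediction, assuming dense NumPy inputs; the function name `fm_predict_linear` is hypothetical. It computes the pairwise term as $\frac{1}{2}\sum_{f=1}^{k}\big[(\sum_i v_{i,f}x_i)^2 - \sum_i v_{i,f}^2x_i^2\big]$ with two column sums of an $n \times k$ array, then cross-checks the result against the explicit $O(kn^2)$ pairwise sum.

```python
import numpy as np

def fm_predict_linear(x, w0, w, V):
    """Second-order FM prediction in O(k n), following eq. (6).

    The pairwise term is evaluated per factor f as
    0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ].
    """
    xv = V * x[:, None]             # (n, k): entry [i, f] = v_{i,f} * x_i
    sum_sq = xv.sum(axis=0) ** 2    # (k,): (sum_i v_{i,f} x_i)^2
    sq_sum = (xv ** 2).sum(axis=0)  # (k,): sum_i (v_{i,f} x_i)^2
    return w0 + float(w @ x) + 0.5 * float((sum_sq - sq_sum).sum())

# Cross-check against the explicit O(k n^2) pairwise sum on random data.
rng = np.random.default_rng(2)
n, k = 8, 4
x, w, V = rng.normal(size=n), rng.normal(size=n), rng.normal(size=(n, k))
pairwise = sum(V[i] @ V[j] * x[i] * x[j] for i in range(n) for j in range(i + 1, n))
print(np.isclose(fm_predict_linear(x, 0.1, w, V), 0.1 + w @ x + pairwise))  # True
```

With sparse input only the nonzero $x_i$ contribute to either column sum, so a version that iterates over nonzero entries only would cost $O(k \cdot \mathrm{nnz}(x))$; the dense version above keeps the code short.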

Gradients (Backpropagation) of Second-Order FM

In eq. (6), the parameters to learn are $w_0$, $w_i$ and $v_{i,f}$.

Derivative with respect to $w_0$: $\dfrac{\partial \hat y}{\partial w_0} = 1$

Derivative with respect to $w_i$: $\dfrac{\partial \hat y}{\partial w_i} = x_i$

Derivative with respect to $v_{i,f}$: $\dfrac{\partial \hat y}{\partial v_{i,f}} = x_i\sum_{j=1}^{n}v_{j,f}x_j - v_{i,f}x_i^2$

The sum $\sum_{j=1}^{n}v_{j,f}x_j$ does not depend on $i$, so it can be computed once per $f$ (it is already needed for the prediction) and reused; all gradients are then obtained in $O(kn)$.
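The gradients can be written in vectorized form. In the sketch below (NumPy; `fm_predict` and `fm_grads` are hypothetical names), $s_f = \sum_j v_{j,f}x_j$ is computed once and reused for every $i$, and a central finite difference on one entry of $V$ checks the formula for $\partial \hat y / \partial v_{i,f}$.

```python
import numpy as np

def fm_predict(x, w0, w, V):
    """Second-order FM prediction in O(k n), per-factor form of eq. (6)."""
    xv = V * x[:, None]
    return w0 + w @ x + 0.5 * np.sum(xv.sum(axis=0) ** 2 - (xv ** 2).sum(axis=0))

def fm_grads(x, w0, w, V):
    """Gradients of y_hat (not of a loss) with respect to w0, w and V."""
    s = V.T @ x                                   # (k,): s_f = sum_j v_{j,f} x_j, shared by all i
    g_w0 = 1.0
    g_w = x.copy()
    g_V = np.outer(x, s) - V * (x ** 2)[:, None]  # [i, f] = x_i s_f - v_{i,f} x_i^2
    return g_w0, g_w, g_V

# Central finite-difference check on the single entry V[2, 1].
rng = np.random.default_rng(3)
n, k = 5, 3
x, w, V = rng.normal(size=n), rng.normal(size=n), rng.normal(size=(n, k))
_, g_w, g_V = fm_grads(x, 0.0, w, V)
eps = 1e-6
Vp, Vm = V.copy(), V.copy()
Vp[2, 1] += eps
Vm[2, 1] -= eps
fd = (fm_predict(x, 0.0, w, Vp) - fm_predict(x, 0.0, w, Vm)) / (2 * eps)
print(np.isclose(fd, g_V[2, 1]))  # True
```

These are derivatives of $\hat y$ itself. For training, multiply them by the derivative of the loss with respect to $\hat y$; for the squared loss $\frac{1}{2}(\hat y - y)^2$ that factor is $\hat y - y$.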

Higher-Order FM

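Let $d$ be the highest interaction order, i.e. the maximum number of features combined in one interaction term. The interactions of order $l$ use their own factor matrix $V^{(l)} \in \mathbb{R}^{n \times k_l}$. Then: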

$$\hat y(x)=w_0+\sum_{i=1}^{n}w_ix_i + \sum_{l=2}^{d}\sum_{i_1=1}^{n}\cdots\sum_{i_l=i_{l-1}+1}^{n}\Big(\prod_{j=1}^{l}x_{i_j}\Big)\Big(\sum_{f=1}^{k_l}\prod_{j=1}^{l}v_{i_j,f}^{(l)}\Big)$$

Evaluating this directly costs $O(k_d n^d)$, but with the same kind of rearrangement as in the second-order case it can also be computed in linear time.
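The text above only states that linear time is achievable. One way to see it (my own construction, not necessarily the paper's): for a fixed order $l$ and factor $f$, the inner sum over $i_1 < \dots < i_l$ of $\prod_j v^{(l)}_{i_j,f}x_{i_j}$ is the elementary symmetric polynomial $e_l$ of the values $a_i = v^{(l)}_{i,f}x_i$, and a simple dynamic program evaluates $e_l$ in $O(nl)$. The sketch compares this against brute-force enumeration; the dict-of-matrices layout `Vs` and all function names are assumptions.

```python
import itertools
import numpy as np

def elem_sym(a, l):
    """e_l(a) = sum over i_1 < ... < i_l of prod_j a_{i_j}, in O(n * l)."""
    e = np.zeros(l + 1)
    e[0] = 1.0
    for ai in a:
        # Update from high to low order so each a_i enters a term at most once.
        for j in range(l, 0, -1):
            e[j] += ai * e[j - 1]
    return e[l]

def hofm_interactions(x, Vs):
    """Interaction part of the d-way FM (orders l = 2..d) via elem_sym.

    Vs: hypothetical dict {l: V_l}, where V_l has shape (n, k_l).
    """
    total = 0.0
    for l, V_l in Vs.items():
        for f in range(V_l.shape[1]):
            total += elem_sym(V_l[:, f] * x, l)  # a_i = v^{(l)}_{i,f} * x_i
    return total

def hofm_interactions_naive(x, Vs):
    """Direct evaluation over all index tuples i_1 < ... < i_l: O(k_l n^l)."""
    total = 0.0
    for l, V_l in Vs.items():
        for idx in itertools.combinations(range(len(x)), l):
            idx = list(idx)
            # (prod_j x_{i_j}) * (sum_f prod_j v^{(l)}_{i_j,f})
            total += np.prod(x[idx]) * np.sum(np.prod(V_l[idx, :], axis=0))
    return total

rng = np.random.default_rng(4)
x = rng.normal(size=6)
Vs = {2: rng.normal(size=(6, 3)), 3: rng.normal(size=(6, 2))}
print(np.isclose(hofm_interactions(x, Vs), hofm_interactions_naive(x, Vs)))  # True
```

The total cost of `hofm_interactions` is $O(n\sum_{l=2}^{d} l\,k_l)$, which is linear in $n$ for fixed $d$ and $k_l$.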
