引言
圖卷積網絡GCN取得了非常好的效果。
對于小型圖,可以全部加載到記憶體中,每層Layer的卷積操作都會周遊全圖。
對于中大型圖,全部加載到記憶體的做法,顯然不能滿足需求。我們會使用mini-batch而不是全圖來進行計算。
下面将介紹三種目前常見的Batch技巧,分别來自GraphSage和ScalableGCN。
1. GraphSage Batch技巧
如上圖所示,h0是模型inputs資料。3層圖卷積Layers,分别是h1, h2, h3。
第k層
中的每個點對應的
通過下面公式來更新
先通過AGGREGATE算子将鄰居節點的embedding聚集起來,再和節點自身的embedding資訊combine,用輸出結果來更新自身embedding。
我們可以觀察到,為了計算最後一層圖卷積Layer h3的某個點的embedding,隻需要h3->h2->h1->h0中涉及到的點(上圖中藍色的點)。相對于全部的點(灰色+藍色的點),這個計算範圍已經大大縮小了。考慮到有些點可能關聯非常多的鄰接點,如果使用全部鄰接點來計算AGGREGATE,那麼計算複雜度将會不可控。
為了解決這個問題,GrapheSage提出抽樣鄰接點的方法。模型限制每個節點抽樣鄰接點的數量m,m是一個可配置的超參數。如果m值越大,那麼模型計算越準确,但是計算代價也越大。反之m值越小,模型計算越不準确,但是計算代價也越小。
具體采樣的方式是:
- 當鄰接點數量小于m時,通過重複采樣的方式補齊到m個
- 當鄰接點數量大于m時,從中随機抽取m個
GraphSage采取Uniform采樣方式,是基于所有鄰接點的重要性是一樣的假設。這個假設在很多實際問題中是不成立的,之後的論文裡通過使用Importance Sampling或者引入Attention機制來進行優化。
使用鄰接點數為m的采樣之後,對于L層的圖卷積網絡,一個mini-batch中就隻會涉及到
個節點,避免了對全圖的周遊。
2. ScalableGCN Batch技巧
ScalableGCN是阿裡媽媽提出的一種在大規模圖上加速Mini-batch GCN訓練速度方法。這個方法也包含在最近開源的大規模分布式圖學習架構Euler裡面。
ScalableGCN的官方介紹連結是
alibaba/eulergithub.com
這個wiki裡面方法介紹的比較粗略,下面我會結合實作代碼細節,具體介紹ScalableGCN算法。
圖3,取自https://github.com/alibaba/euler/wiki/ScalableGCN
在第一部分裡面,我們介紹了GraphSage采用的batch方式,雖然相比于全圖計算,已經将計算量大幅減少,但是可以看到它的複雜度還是和卷積層數L成指數級關系。
在
圖1中低階的embedding(比如h0, h1, h2)會在計算不同點的高階embedding(h3)時,被大量重複使用。如果我們把這部分低階的embedding緩存起來,就可以減少除了自己節點計算以外,其他鄰接點的計算。
具體地,對于
層GCN模型,開辟存儲空間:
,将mini-batch SGD中頂點最新的前
層的embedding存儲起來。同時,我們修改GCN模型為:
即在彙聚的時候使用緩存中的embedding值,這樣一來我們隻需計算mini-batch中的樣本頂點的卷積結果,無需對擴散後的
階所有鄰接頂點進行卷積計算。我們用中心頂點(圖3中藍色節點)的embedding
更新
。
self
源碼中使用trainable=False的tf variable來存儲前
層的embedding。trainable設為False是因為這部分緩存的embedding不需要參與反向求導。
for
當頂層node_embedding計算更新之後,我們用這個更新之後的embedding值來更新stores裡面對應節點的緩存。
下面介紹模型的反向求導過程
圖4,取自https://github.com/alibaba/euler/wiki/ScalableGCN
我們開辟存儲空間:
來存儲前
層緩存embedding的導數。
self
類似的源碼中使用trainable=False的tf variable來存儲這部分導數。
因為前
層緩存embedding是trainable為False的variable,是以我們需要手動通過tf.gradients來計算對應的導數,并存到graident_stores緩存裡面。源碼實作如下
for
下面我們需要通過Back-Propagation來更新模型參數
。
如果沒有因為緩存資料結構,更新模型參數很簡單,隻需要執行optimizer.minimize(loss),tensorflow就會自動更新參數
。
但是這裡我們為了避免重複計算引入了緩存結構(trainable=False的variable),這些variable的導數是我們手動掉tf.gradient接口計算出來的。根據鍊式求導法則
其中
就是graident_stores緩存
。為了更新
,我們隻需要将loss改為
,再通過optimizer.minimize就可以更新參數
了。具體實作如下
def