天天看點

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

引言

圖卷積網絡GCN取得了非常好的效果。

對于小型圖,可以全部加載到記憶體中,每層Layer的卷積操作都會周遊全圖。

對于中大型圖,全部加載到記憶體的做法,顯然不能滿足需求。我們會使用mini-batch而不是全圖來進行計算。

下面将介紹三種目前常見的Batch技巧,分别來自GraphSage和ScalableGCN。

1. GraphSage Batch技巧

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

如上圖所示,h0是模型inputs資料。3層圖卷積Layers,分别是h1, h2, h3。

第k層

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

中的每個點對應的

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

通過下面公式來更新

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧
batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

先通過AGGREGATE算子将鄰居節點的embedding聚集起來,再和節點自身的embedding資訊combine,用輸出結果來更新自身embedding。

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

我們可以觀察到,為了計算最後一層圖卷積Layer h3的某個點的embedding,隻需要h3->h2->h1->h0中涉及到的點(上圖中藍色的點)。相對于全部的點(灰色+藍色的點),這個計算範圍已經大大縮小了。考慮到有些點可能關聯非常多的鄰接點,如果使用全部鄰接點來計算AGGREGATE,那麼計算複雜度将會不可控。

為了解決這個問題,GrapheSage提出抽樣鄰接點的方法。模型限制每個節點抽樣鄰接點的數量m,m是一個可配置的超參數。如果m值越大,那麼模型計算越準确,但是計算代價也越大。反之m值越小,模型計算越不準确,但是計算代價也越小。

具體采樣的方式是:

  • 當鄰接點數量小于m時,通過重複采樣的方式補齊到m個
  • 當鄰接點數量大于m時,從中随機抽取m個

GraphSage采取Uniform采樣方式,是基于所有鄰接點的重要性是一樣的假設。這個假設在很多實際問題中是不成立的,之後的論文裡通過使用Importance Sampling或者引入Attention機制來進行優化。

使用鄰接點數為m的采樣之後,對于L層的圖卷積網絡,一個mini-batch中就隻會涉及到

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

個節點,避免了對全圖的周遊。

2. ScalableGCN Batch技巧

ScalableGCN是阿裡媽媽提出的一種在大規模圖上加速Mini-batch GCN訓練速度方法。這個方法也包含在最近開源的大規模分布式圖學習架構Euler裡面。

ScalableGCN的官方介紹連結是

alibaba/euler​github.com

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

這個wiki裡面方法介紹的比較粗略,下面我會結合實作代碼細節,具體介紹ScalableGCN算法。

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

圖3,取自https://github.com/alibaba/euler/wiki/ScalableGCN

在第一部分裡面,我們介紹了GraphSage采用的batch方式,雖然相比于全圖計算,已經将計算量大幅減少,但是可以看到它的複雜度還是和卷積層數L成指數級關系。

圖1

中低階的embedding(比如h0, h1, h2)會在計算不同點的高階embedding(h3)時,被大量重複使用。如果我們把這部分低階的embedding緩存起來,就可以減少除了自己節點計算以外,其他鄰接點的計算。

具體地,對于

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

層GCN模型,開辟存儲空間:

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

,将mini-batch SGD中頂點最新的前

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

層的embedding存儲起來。同時,我們修改GCN模型為:

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧
batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

即在彙聚的時候使用緩存中的embedding值,這樣一來我們隻需計算mini-batch中的樣本頂點的卷積結果,無需對擴散後的

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

階所有鄰接頂點進行卷積計算。我們用中心頂點(圖3中藍色節點)的embedding

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

更新

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

self
           

源碼中使用trainable=False的tf variable來存儲前

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

層的embedding。trainable設為False是因為這部分緩存的embedding不需要參與反向求導。

for 
           

當頂層node_embedding計算更新之後,我們用這個更新之後的embedding值來更新stores裡面對應節點的緩存。

下面介紹模型的反向求導過程

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

圖4,取自https://github.com/alibaba/euler/wiki/ScalableGCN

我們開辟存儲空間:

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

來存儲前

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

層緩存embedding的導數。

self
           

類似的源碼中使用trainable=False的tf variable來存儲這部分導數。

因為前

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

層緩存embedding是trainable為False的variable,是以我們需要手動通過tf.gradients來計算對應的導數,并存到graident_stores緩存裡面。源碼實作如下

for 
           

下面我們需要通過Back-Propagation來更新模型參數

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

如果沒有因為緩存資料結構,更新模型參數很簡單,隻需要執行optimizer.minimize(loss),tensorflow就會自動更新參數

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

但是這裡我們為了避免重複計算引入了緩存結構(trainable=False的variable),這些variable的導數是我們手動掉tf.gradient接口計算出來的。根據鍊式求導法則

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

其中

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

就是graident_stores緩存

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

。為了更新

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

,我們隻需要将loss改為

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

,再通過optimizer.minimize就可以更新參數

batch & print pro_圖卷積網絡(GCN) Mini-Batch技巧

了。具體實作如下

def