天天看點

【GNN】task6-基于圖神經網絡的圖表征學習方法

學習心得

  • 這次學習了基于圖同構網絡(GIN)的圖表征網絡。為了得到圖表征首先需要做節點表征,然後做圖讀出。GIN中節點表征的計算遵循WL Test算法中節點标簽的更新方法,是以它的上界是WL Test算法。

    在圖讀出中,我們對所有的節點表征(權重,如果用Attention的話)求和,這會造成節點分布資訊的丢失。

  • 為了研究圖神經網絡的表達力問題,産生一個重要模型——圖同構模型,​

    ​Weisfeiler-Lehman​

    ​測試就是檢測兩個圖是否在拓撲結構上圖同構的近似方法;該測試最大的特點是:對每個節點的子樹的聚合函數采用的是單射(​

    ​Injective​

    ​)的散列函數。

    ——由該特點我們可以通過設計一個單射聚合聚合函數來設計與WL一樣強大的圖卷積網絡(同時,圖同構網絡有強大的圖區分能力,适合圖分類任務)。

  • 具體還要讀相關論文,關注阿裡算法大佬的知乎專欄(tql):​​https://zhuanlan.zhihu.com/p/90645716​​

文章目錄

  • ​​學習心得​​
  • ​​引言​​
  • ​​一、基于圖同構網絡(GIN)的圖表征網絡的實作​​
  • ​​1.基于圖同構網絡的圖表征子產品(GINGraphRepr Module)​​
  • ​​圖表征子產品運作流程​​
  • ​​2.基于圖同構網絡的節點嵌入子產品(GINNodeEmbedding Module)​​
  • ​​3.GINConv--圖同構卷積層​​
  • ​​4.AtomEncoder 與 BondEncoder​​
  • ​​OGB簡介:​​
  • ​​二、How Powerful are Graph Neural Networks?​​
  • ​​1.Motivation​​
  • ​​2.文章内容​​
  • ​​3.背景:Weisfeiler-Lehman Test (WL Test)​​
  • ​​(1)圖同構性測試算法WL Test​​
  • ​​背景介紹​​
  • ​​WL舉例說明(以一維為栗子)​​
  • ​​第一步:聚合​​
  • ​​第二步:标簽散列(哈希)​​
  • ​​注:怎樣的聚合函數是一個單射函數?​​
  • ​​第三步:給節點重新打上标簽。​​
  • ​​第四步:數标簽​​
  • ​​第五步:判斷同構性​​
  • ​​(2)WL Subtree Kernel圖相似性評估(定量化)​​
  • ​​4.小結​​
  • ​​三、作業​​
  • ​​四、論文相關​​
  • ​​REFERENCE​​

引言

在此篇文章中我們将學習基于圖神經網絡的圖表征學習方法,圖表征學習要求根據節點屬性、邊和邊的屬性(如果有的話)生成一個向量作為圖的表征,基于圖表征我們可以做圖的預測。基于圖同構網絡(Graph Isomorphism Network, GIN)的圖表征網絡是目前最經典的圖表征學習網絡,我們将以它為例,通過該網絡的實作、項目實踐和理論分析,三個層面來學習基于圖神經網絡的圖表征學習方法。

提出圖同構網絡的論文:​​How Powerful are Graph Neural Networks? ​​

一、基于圖同構網絡(GIN)的圖表征網絡的實作

基于圖同構網絡的圖表征學習主要包含以下兩個過程:

  1. 首先計算得到節點表征;
  2. 其次對圖上各個節點的表征做圖池化(Graph Pooling),或稱為圖讀出(Graph Readout),得到圖的表征(Graph Representation)。

自頂向下的學習順序:GIN圖表征—>節點表征

1.基于圖同構網絡的圖表征子產品(GINGraphRepr Module)

GIN圖同構網絡模型的建構

  • 能實作判斷圖同構性的圖神經網絡需要滿足,隻在兩個節點自身标簽一樣且它們的鄰接節點一樣時,圖神經網絡将這兩個節點映射到相同的表征,即映射是單射性的。
  • 可重複集合/多重集(Multisets):元素可重複的集合,元素在集合中沒有順序關系。一個節點的所有鄰接節點是一個可重複集合,一個節點可以有重複的鄰接節點,鄰接節點沒有順序關系。是以GIN模型中生成節點表征的方法遵循WL Test算法更新節點标簽的過程。

在生成節點的表征後仍需要執行圖池化(或稱為圖讀出)操作得到圖表征,最簡單的圖讀出操作是做求和。由于每一層的節點表征都可能是重要的,是以在圖同構網絡中,不同層的節點表征在求和後被拼接,其數學定義如下,

采用拼接而不是相加的原因在于不同層節點的表征屬于不同的特征空間。未做嚴格的證明,這樣得到的圖的表示與WL Subtree Kernel得到的圖的表征是等價的。

圖表征子產品運作流程

(1)首先采用​

​GINNodeEmbedding​

​​子產品對圖上每一個節點做節點嵌入(Node Embedding),得到節點表征;

(2)然後對節點表征做圖池化得到圖的表征;

(3)最後用一層線性變換對圖表征轉換為對圖的預測。

代碼實作如下:

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding

class GINGraphRepr(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module
        Args:
            num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了圖表征的次元,dimension of graph representation).
            num_tasks為圖表征次元,預設1
            num_layers (int, optional): GINConv層數,預設5
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            emb_dim為節點表征的次元,預設為300
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): 可選的值為"last"和"sum"。選"last",隻取最後一層的結點的嵌入,選"sum"對各層的結點的嵌入求和。Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding. 可選的值為"sum","mean","max","attention"和"set2set"。 Defaults to "sum".
            graph_pooling為圖節點表征的池化方法

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        
        # 用GINNodeEmbedding子產品對圖上每一個節點做節點嵌入,得到節點表征
        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
        
        # 對節點表征做圖池化得到圖的表征
        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, relu is applied to output to ensure positivity
            # 因為預測目标的取值範圍就在 (0, 50] 内
            return torch.clamp(output, min=0, max=50)      

可以看到可選的基于結點表征計算得到圖表征的方法有:

  1. “sum”:
  • 對節點表征求和;
  • 使用子產品​​torch_geometric.nn.glob.global_add_pool​​。
  1. “mean”:
  • 對節點表征求平均;
  • 使用子產品​​torch_geometric.nn.glob.global_mean_pool​​。
  1. “max”:取節點表征的最大值。
  • 對一個batch中所有節點計算節點表征各個次元的最大值;
  • 使用子產品​​torch_geometric.nn.glob.global_max_pool​​。
  1. “attention”:
  • 基于Attention對節點表征權重求和;
  • 使用子產品​​torch_geometric.nn.glob.GlobalAttention​​;
  • 來自論文​​“Gated Graph Sequence Neural Networks”​​ 。
  1. “set2set”:
  1. 另一種基于Attention對節點表征權重求和的方法;
  2. 使用子產品​​torch_geometric.nn.glob.Set2Set​​;
  3. 來自論文​​“Order Matters: Sequence to sequence for sets”​​。

PyG中內建的所有的圖池化的方法可見于​​Global Pooling Layers​​。

2.基于圖同構網絡的節點嵌入子產品(GINNodeEmbedding Module)

此節點嵌入子產品基于多層​

​GINConv​

​​實作結點嵌入的計算。此處我們先忽略​

​GINConv​

​的實作。輸入到此節點嵌入子產品的節點屬性為類别型向量,

【GNN】task6-基于圖神經網絡的圖表征學習方法

(1)我們首先用​

​AtomEncoder​

​對其做嵌入得到第​

​0​

​層節點表征(稍後我們再對​

​AtomEncoder​

​做分析)。

(2)然後我們逐層計算節點表征,從第​

​1​

​​層開始到第​

​num_layers​

​層,每一層節點表征的計算都以上一層的節點表征​

​h_list[layer]​

​、邊​

​edge_index​

​和邊的屬性​

​edge_attr​

​為輸入。

(3)需要注意的是,​

​GINConv​

​的層數越多,此節點嵌入子產品的感受野(receptive field)越大,結點​

​i​

​的表征最遠能捕獲到結點​

​i​

​的距離為​

​num_layers​

​的鄰接節點的資訊。

import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # 先将類别型原子屬性轉化為原子表征
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return      

接下來我們來學習圖同構網絡的關鍵元件​

​GINConv​

​。

3.GINConv–圖同構卷積層

圖同構卷積層的數學定義如下:

PyG中已經實作了此子產品,我們可以通過​​​torch_geometric.nn.GINConv​​來使用PyG定義好的圖同構卷積層,然而該實作不支援存在邊屬性的圖。在這裡我們自己自定義一個支援邊屬性的​

​GINConv​

​子產品。

由于輸入的邊屬性為類别型,是以我們需要先将類别型邊屬性轉換為邊表征。我們定義的​

​GINConv​

​子產品遵循“消息傳遞、消息聚合、消息更新”這一過程。

  • 這一過程随着​

    ​self.propagate()​

    ​​方法的調用開始執行,該函數接收​

    ​edge_index​

    ​​,​

    ​x​

    ​​,​

    ​edge_attr​

    ​​此三個參數。​

    ​edge_index​

    ​​是形狀為​

    ​[2,num_edges]​

    ​的張量(tensor)。
def forward(self, x, edge_index, edge_attr):
        # 先将類别型邊屬性轉換為邊表征
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return      
  • 在消息傳遞過程中,此張量首先被按行拆分為​

    ​x_i​

    ​​和​

    ​x_j​

    ​​張量,​

    ​x_j​

    ​​表示了消息傳遞的源節點,​

    ​x_i​

    ​表示了消息傳遞的目标節點。
  • 接着​

    ​message()​

    ​​方法被調用,此函數定義了從源節點傳入到目标節點的消息,在這裡要傳遞的消息是源節點表征與邊表征之和的​

    ​relu()​

    ​​的輸出。我們在​

    ​super(GINConv, self).__init__(aggr = "add")​

    ​​中定義了消息聚合方式為​

    ​add​

    ​​,那麼傳入給任一個目标節點的所有消息被求和得到​

    ​aggr_out​

    ​,它還是目标節點的中間過程的資訊。
def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)      
  • 接着執行消息更新過程,我們的類​

    ​GINConv​

    ​​繼承了​

    ​MessagePassing​

    ​​類,是以​

    ​update()​

    ​​函數被調用。然而我們希望對節點做消息更新中加入目标節點自身的消息,是以在​

    ​update​

    ​​函數中我們隻簡單傳回輸入的​

    ​aggr_out​

    ​。
def update(self, aggr_out):
        return      
  • 然後在​

    ​forward​

    ​​函數中我們執行​

    ​out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))​

    ​實作消息的更新。
def forward(self, x, edge_index, edge_attr):
        # 先将類别型邊屬性轉換為邊表征
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return      

流程圖如下(如有錯誤,請多多指正~)

【GNN】task6-基于圖神經網絡的圖表征學習方法

上面的完整代碼:

import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        # 先将類别型邊屬性轉換為邊表征
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)
        
    def update(self, aggr_out):
        return      

4.AtomEncoder 與 BondEncoder

由于在目前的例子中,節點(原子)和邊(化學鍵)的屬性都為離散值,它們屬于不同的空間,無法直接将它們融合在一起。通過嵌入(Embedding),我們可以将節點屬性和邊屬性分别映射到一個新的空間,在這個新的空間中,我們就可以對節點和邊進行資訊融合。在​

​GINConv​

​​中,​

​message()​

​​函數中的​

​x_j + edge_attr​

​ 操作執行了節點資訊和邊資訊的融合。

接下來,我們通過下方的代碼中的​

​AtomEncoder​

​類,來分析将節點屬性映射到一個新的空間是如何實作的:

  • ​full_atom_feature_dims​

    ​​ 是一個連結清單​

    ​list​

    ​​,存儲了節點屬性向量每一維可能取值的數量,即​

    ​X[i]​

    ​​ 可能的取值一共有​

    ​full_atom_feature_dims[i]​

    ​​種情況,​

    ​X​

    ​為節點屬性;
  • 節點屬性有多少維,那麼就需要有多少個嵌入函數,通過調用​

    ​torch.nn.Embedding(dim, emb_dim)​

    ​可以執行個體化一個嵌入函數;
  • ​torch.nn.Embedding(dim, emb_dim)​

    ​​,第一個參數​

    ​dim​

    ​​為被嵌入資料可能取值的數量,第一個參數​

    ​emb_dim​

    ​​為要映射到的空間的次元。得到的嵌入函數接受一個大于​

    ​0​

    ​​小于​

    ​dim​

    ​​的數,輸出一個次元為​

    ​emb_dim​

    ​的向量。嵌入函數也包含可訓練參數,通過對神經網絡的訓練,嵌入函數的輸出值能夠表達不同輸入值之間的相似性。
  • 在​

    ​forward()​

    ​函數中,我們對不同屬性值得到的不同嵌入向量進行了相加操作,實作了将節點的的不同屬性融合在一起。

​BondEncoder​

​​類與​

​AtomEncoder​

​類是類似的。

OGB簡介:

Open Graph Benchmark (OGB) 是斯坦福大學的一個圖深度學習的基準資料集。可以直接通過​

​pip install ogb​

​下載下傳;在OGB資料集上,帶虛拟節點的模型相對于原來的模型基本上都得到了不小的提升。

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

class AtomEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        
        self.atom_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0在·
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])

        return x_embedding


class BondEncoder(torch.nn.Module):
    
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        
        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])

        return bond_embedding   


if __name__ == '__main__':
    # from loader import GraphClassificationPygDataset
    # dataset = GraphClassificationPygDataset(name = 'tox21')
    
    from ogb.graphproppred.dataset_pyg import PygGraphPropPredDataset
    dataset = PygGraphPropPredDataset(name = 'ogbg-molhiv')
    
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)

    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))      

上面的tox21資料集(藥物分子的一個資料集​​https://tripod.nih.gov/tox21/challenge/data.jsp#​​​)導入的loader好像有點問題,換用了​

​ogbg-molhiv​

​​資料集并且用​

​PygGraphPropPredDataset​

​導入後代碼運作結果如下,可以得到節點屬性映射到一個新的空間後的節點屬性。

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip
Downloaded 0.00 GB: 100%|████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.01s/it]
Extracting dataset\hiv.zip
Processing...
Loading necessary files...
This might take a while.
 25%|██████████████████▏                                                      | 10223/41127 [00:00<00:00, 55068.97it/s]
Processing graphs...
100%|█████████████████████████████████████████████████████████████████████████| 41127/41127 [00:00<00:00, 65315.60it/s]
 20%|██████████████▉                                                           | 8282/41127 [00:00<00:00, 82582.23it/s]
Converting graphs into PyG objects...
100%|█████████████████████████████████████████████████████████████████████████| 41127/41127 [00:00<00:00, 46464.62it/s]
Saving...
Done!
tensor([[ 6.5439e-01, -8.1690e-01,  1.4390e-01,  ..., -5.1183e-01,
          7.9147e-01, -5.8364e-02],
        [ 6.1138e-01, -7.0429e-01, -4.4907e-02,  ..., -4.3347e-01,
          7.7738e-01, -3.7608e-01],
        [ 1.6390e-01, -2.4293e-01,  2.1663e-01,  ..., -7.5048e-04,
          6.4839e-01,  3.3260e-01],
        ...,
        [ 6.1138e-01, -7.0429e-01, -4.4907e-02,  ..., -4.3347e-01,
          7.7738e-01, -3.7608e-01],
        [ 6.5439e-01, -8.1690e-01,  1.4390e-01,  ..., -5.1183e-01,
          7.9147e-01, -5.8364e-02],
        [ 8.3114e-01, -3.9946e-01,  1.2739e-01,  ..., -3.1377e-01,
          5.1153e-01,  5.1009e-01]], grad_fn=<AddBackward0>)
tensor([[-0.2277,  0.0363,  0.2976,  ..., -0.1147, -0.0277, -0.2164],
        [-0.2277,  0.0363,  0.2976,  ..., -0.1147, -0.0277, -0.2164],
        [-0.2277,  0.0363,  0.2976,  ..., -0.1147, -0.0277, -0.2164],
        ...,
        [-0.2277,  0.0363,  0.2976,  ..., -0.1147, -0.0277, -0.2164],
        [-0.2277,  0.0363,  0.2976,  ..., -0.1147, -0.0277, -0.2164],
        [-0.2277,  0.0363,  0.2976,  ..., -0.1147, -0.0277, -0.2164]],
       grad_fn=<AddBackward0>)      

二、How Powerful are Graph Neural Networks?

提出圖同構網絡的論文:​​How Powerful are Graph Neural Networks? ​​

1.Motivation

新的圖神經網絡的設計大多基于經驗性的直覺、啟發式的方法和實驗性的試錯。人們對圖神經網絡的特性和局限性了解甚少,對圖神經網絡的表征能力學習的正式分析也很有限。

2.文章内容

  1. (理論上)圖神經網絡在區分圖結構方面最高能達到與WL Test一樣的能力。
  2. 确定了鄰接節點聚合方法和圖池化方法應具備的條件,在這些條件下,所産生的圖神經網絡能達到與WL Test一樣的能力。
  3. 分析過去流行的圖神經網絡變體(如GCN和GraphSAGE)無法區分一些結構的圖。
  4. 開發了一個簡單的圖神經網絡模型–圖同構網絡(Graph Isomorphism Network, GIN),并證明其分辨同構圖的能力和表示圖的能力與WL Test相當。

3.背景:Weisfeiler-Lehman Test (WL Test)

(1)圖同構性測試算法WL Test

背景介紹

兩個圖是同構的,意思是兩個圖擁有一樣的拓撲結構,也就是說,我們可以通過重新标記節點從一個圖轉換到另外一個圖。Weisfeiler-Lehman 圖的同構性測試算法,簡稱WL Test,是一種用于測試兩個圖是否同構的算法。

WL Test 的一維形式,類似于圖神經網絡中的鄰接節點聚合。WL Test

1)疊代地聚合節點及其鄰接節點的标簽,然後 2)将聚合的标簽散列(hash)成新标簽,該過程形式化為下方的公式,

符号:表示節點的第次疊代的标簽,第次疊代的标簽為節點原始标簽。

在疊代過程中,發現兩個圖之間的節點的标簽不同時,就可以确定這兩個圖是非同構的。需要注意的是節點标簽可能的取值隻能是有限個數。

【GNN】task6-基于圖神經網絡的圖表征學習方法
WL舉例說明(以一維為栗子)

給定兩個圖和,每個節點擁有标簽(實際中,一些圖沒有節點标簽,我們可以以節點的度作為标簽)。

【GNN】task6-基于圖神經網絡的圖表征學習方法

Weisfeiler-Leman Test 算法通過重複執行以下給節點打标簽的過程來實作圖是否同構的判斷:

第一步:聚合

聚合自身與鄰接節點的标簽得到一串字元串,自身标簽與鄰接節點的标簽中間用​

​,​

​分隔,鄰接節點的标簽按升序排序。排序的原因在于要保證單射性,即保證輸出的結果不因鄰接節點的順序改變而改變。

如下圖就是,每個節點有個一個label(此處表示節點的度)。

【GNN】task6-基于圖神經網絡的圖表征學習方法

如下圖,做标簽的擴充:做一階BFS,即隻周遊自己的鄰居,比如在下圖中G中原5号節點變成(5,234),這是因為原(5)節點的一階鄰居有2、3、4。

【GNN】task6-基于圖神經網絡的圖表征學習方法

第二步:标簽散列(哈希)

即标簽壓縮,将較長的字元串映射到一個簡短的标簽。

如下圖,僅僅是把擴充标簽映射成一個新标簽,如5,234映射為13

【GNN】task6-基于圖神經網絡的圖表征學習方法

注:怎樣的聚合函數是一個單射函數?

什麼是單射函數?

單射指不同的輸入值一定會對應到不同的函數值。如果對于每一個y存在最多一個定義域内的x,有f(x)=y,則函數f被稱為單射函數。

看一個栗子:

兩個節點v1和v2,其中v1的鄰接點是1個黃球和1個藍球,v2的鄰接點是2個鄰接點是2個黃球和2個藍球。最常用的聚合函數包含圖卷積網絡中所使用的均值聚合,以及GraphSAGE中常用的均值聚合或最大值聚合。

(1)如果使用均值聚合或者最大值聚合,聚合後v1的狀态是(黃,藍),而v2的狀态也是(黃,藍),顯然它們把本應不同的2個節點映射到了同一個狀态,這不滿足單射的定義。

(2)如果使用求和函數,v1的狀态是(黃,藍),而v2的狀态是(2×黃,2×藍),也就分開了。

【GNN】task6-基于圖神經網絡的圖表征學習方法

可以看出WL測試最大的特點是:對每個節點的子樹的聚合函數采用的是單射(Injective)的散列函數。

第三步:給節點重新打上标簽。

繼續一開始的栗子,

【GNN】task6-基于圖神經網絡的圖表征學習方法

第四步:數标簽

如下圖,在G網絡中,含有1号标簽2個,那麼第一個數字就是1。這些标簽的個數作為整個網絡的新特征。

【GNN】task6-基于圖神經網絡的圖表征學習方法

每重複一次以上的過程,就完成一次節點自身标簽與鄰接節點标簽的聚合。

第五步:判斷同構性

當出現兩個圖相同節點标簽的出現次數不一緻時,即可判斷兩個圖不相似。如果上述的步驟重複一定的次數後,沒有發現有相同節點标簽的出現次數不一緻的情況,那麼我們無法判斷兩個圖是否同構。

當兩個節點的層的标簽一樣時,表示分别以這兩個節點為根節點的WL子樹是一緻的。WL子樹與普通子樹不同,WL子樹包含重複的節點。下圖展示了一棵以1節點為根節點高為2的WL子樹。

【GNN】task6-基于圖神經網絡的圖表征學習方法

(2)WL Subtree Kernel圖相似性評估(定量化)

此方法來自于​​Weisfeiler-Lehman Graph Kernels​​。

  • WL測試不能保證對所有圖都有效,特别是對于具有高度對稱性的圖,如鍊式圖、完全圖、環圖和星圖,它會判斷錯誤。
  • WL測試隻能判斷兩個圖的相似性,無法衡量圖之間的相似性。要衡量兩個圖的相似性,我們用WL Subtree Kernel方法。

Weisfeiler-Lehman Graph Kernels 方法提出用WL子樹核衡量圖之間相似性。該方法使用WL Test不同疊代中的節點标簽計數作為圖的表征向量,它具有與WL Test相同的判别能力。在WL Test的第次疊代中,一個節點的标簽代表了以該節點為根的高度為的子樹結構。

該方法的思想是用WL Test算法得到節點的多層的标簽,然後我們可以分别統計圖中各類标簽出現的次數,存于一個向量,這個向量可以作為圖的表征。兩個圖的表征向量的内積,即可作為這兩個圖的相似性估計,内積越大表示相似性越高。

【GNN】task6-基于圖神經網絡的圖表征學習方法

4.小結

大部分空域圖神經網絡的更新步驟,和WL測試非常類似。就像消息傳遞網絡中歸納的架構,大部分基于空域的圖神經網絡都可以歸結為2個步驟:聚合鄰接點資訊(aggregate),更新節點資訊(combine)。

與WL測試一樣,在表達網絡結果時,一個節點的表征會由該結點的父結點的子樹資訊聚合而成。

正如上面提到的栗子中(下圖),均值聚合或者最大值聚合把栗子中的v1和v2兩個節點映射到了同一個狀态(錯誤),而如果使用求和函數則能正确分開兩者狀态。WL測試最大的特點是:對每個節點的子樹的聚合函數采用的是單射(​

​Injective​

​)的散列函數。

——由該特點我們可以通過設計一個單射聚合聚合函數來設計與WL一樣強大的圖卷積網絡(同時,圖同構網絡有強大的圖區分能力,适合圖分類任務)。

【GNN】task6-基于圖神經網絡的圖表征學習方法

三、作業

請畫出下方圖檔中的6号、3号和5号節點為根結點的從1層到3層的WL()子樹(即高為3)。

【GNN】task6-基于圖神經網絡的圖表征學習方法
【GNN】task6-基于圖神經網絡的圖表征學習方法

四、論文相關

(1)《How Powerful are Graph Neural Networks?》 ICLR 2019

這篇主要基于Weisfeiler-Lehman(WL) test 視角理論分析了GNN,證明WL是GNN上限,并分析GCN和GraphSAGE等主流GNN在捕獲圖結構上的不足和特性。

(2)《Weisfeiler and Leman Go Neural: Higher-Order Graph Neural Networks》AAAI-2019

文章在介紹相關方法時主要分成了兩部分,包括後面的對比試驗也是,文章将圖領域内的方法分為兩種:

(1)基于核的方法,例如基于随機遊走或者最短距離核心的等等算法,其中,WL算法是屬于基于核方法的一種方法。

(2)GNN系列的方法,比如Gated Graph Neural Networks,GraphSAGE, SplineCNN等等。

(3)《Weisfeiler and Leman Go Neural: Higher-Order Graph Neural Networks》AAAI-2019

1.證明了GNN在非同構圖區分上并不比WL算法強,并且在某種特定情況下,GNN與WL算法具有同等效力,是以也具有相同的問題

REFERENCE

  • 提出GlobalAttention的論文:​​“Gated Graph Sequence Neural Networks”​​
  • 提出Set2Set的論文:​​“Order Matters: Sequence to sequence for sets”​​
  • PyG中內建的所有的圖池化的方法:​​Global Pooling Layers​​
  • Weisfeiler-Lehman Test:​​Brendan L Douglas. The weisfeiler-lehman method and graph isomorphism testing. arXiv preprint arXiv:1101.5211, 2011.​​
  • datawhale course:https://github.com/datawhalechina
  • ​​Weisfeiler-Lehman Graph Kernels​​
  • ogb包的源碼:https://github.com/snap-stanford/ogb
  • ​​cs224w(圖機器學習)2021冬季課程學習筆記8 Colab 2​​
  • 馬騰飛《圖神經網絡-基礎與前沿》
  • ​​How Powerful are Graph Neural Networks?論文解讀​​

繼續閱讀