天天看點

PyG圖神經網絡架構--建構資訊傳遞網絡(MPN)資訊傳遞網絡

資訊傳遞網絡

将卷積算子泛化到不規則域中,通常表示為鄰域聚合或資訊傳遞模式。 x i ( k − 1 ) ∈ R F x_i^{(k-1)} \in \R^F xi(k−1)​∈RF 表示節點 i i i 在第 l − 1 l-1 l−1 層的節點特征, e j , i ∈ R D e_{j,i}\in \R^D ej,i​∈RD 表示 從節點 j j j 到節點 i i i 的特征,資訊傳遞圖神經網絡可以表示為:

x i ( k ) = γ ( k ) ( x i ( k − 1 ) , ∗ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) x_i^{(k)}=\gamma^{(k)}(x_i^{(k-1)},*_{j\in\mathcal{N}(i)}\phi^{(k)}(x_i^{(k-1)},x_j^{(k-1)},e_{j,i})) xi(k)​=γ(k)(xi(k−1)​,∗j∈N(i)​ϕ(k)(xi(k−1)​,xj(k−1)​,ej,i​)),

*表示一個可微的、置換不變的函數(如求和、平均或最大值), γ \gamma γ 和 ϕ \phi ϕ 表示可微函數(如MLPs)。

MessagePassing類

PyG提供了

MessagePassing

類,它通過自動處理消息傳遞過程來建立此類消息傳遞圖神經網絡。我們隻需要定義函數 ϕ \phi ϕ,,即

message()

, 以及 γ \gamma γ,即

update()

,同樣還有聚合方式,即

aggr = "add" or "mean" or "max"

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)

定義了聚合方式 (求和,平均或最大值),消息傳遞的流向 (sorce_to_target or target_to_source)。此外,

node_dim

表示沿着哪條軸線傳播。

MessagePassing.propagate(edge_index, size=None, **kwargs)

啟動消息傳播的函數。接收邊的索引以及其他參數,這些參數事建構消息和更新節點嵌入所必須的。注意,

propagate()

不僅限于在形狀為 [n,n]的方形鄰接矩陣中交換資訊,還可以傳遞 size=(n,m)作為一個參數,然後在形狀為 [n,m]的一般稀疏矩陣中交換資訊。如果設定為 None,則預設為方形矩陣。

實作GCN layer

GCN層的數學定義為:

x i ( k ) = ∑ j ∈ N ( i ) ∪ i 1 d e g ( i ) ⋅ d e g ( j ) ⋅ ( Θ T ⋅ x j ( k − 1 ) ) x_i^{(k)}=\sum_{j\in\mathcal{N}(i)\cup{i}}\frac{1}{\sqrt{deg(i)}\cdot\sqrt{deg(j)}}\cdot(\Theta^T\cdot x_j^{(k-1)}) xi(k)​=∑j∈N(i)∪i​deg(i)

​⋅deg(j)

​1​⋅(ΘT⋅xj(k−1)​),

其中,相鄰節點的特征首先經過權重矩陣 Θ \Theta Θ 進行變換,按節點的度進行歸一化,最後進行求和。可以分為5個步驟:

  1. 為鄰接矩陣加自環
  2. 對節點特征矩陣執行線性變換
  3. 計算歸一化系數
  4. 歸一化 ϕ \phi ϕ中的節點特征
  5. 求和

步驟1-3通常在消息傳遞前處理的,4-5可以用

MessagePassing

基類進行處理,實作如下:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

           

繼續閱讀