資訊傳遞網絡
将卷積算子泛化到不規則域中,通常表示為鄰域聚合或資訊傳遞模式。 x i ( k − 1 ) ∈ R F x_i^{(k-1)} \in \R^F xi(k−1)∈RF 表示節點 i i i 在第 l − 1 l-1 l−1 層的節點特征, e j , i ∈ R D e_{j,i}\in \R^D ej,i∈RD 表示 從節點 j j j 到節點 i i i 的特征,資訊傳遞圖神經網絡可以表示為:
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , ∗ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) x_i^{(k)}=\gamma^{(k)}(x_i^{(k-1)},*_{j\in\mathcal{N}(i)}\phi^{(k)}(x_i^{(k-1)},x_j^{(k-1)},e_{j,i})) xi(k)=γ(k)(xi(k−1),∗j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
*表示一個可微的、置換不變的函數(如求和、平均或最大值), γ \gamma γ 和 ϕ \phi ϕ 表示可微函數(如MLPs)。
MessagePassing類
PyG提供了
MessagePassing
類,它通過自動處理消息傳遞過程來建立此類消息傳遞圖神經網絡。我們隻需要定義函數 ϕ \phi ϕ,,即
message()
, 以及 γ \gamma γ,即
update()
,同樣還有聚合方式,即
aggr = "add" or "mean" or "max"
。
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
定義了聚合方式 (求和,平均或最大值),消息傳遞的流向 (sorce_to_target or target_to_source)。此外,
node_dim
表示沿着哪條軸線傳播。
MessagePassing.propagate(edge_index, size=None, **kwargs)
啟動消息傳播的函數。接收邊的索引以及其他參數,這些參數事建構消息和更新節點嵌入所必須的。注意,
propagate()
不僅限于在形狀為 [n,n]的方形鄰接矩陣中交換資訊,還可以傳遞 size=(n,m)作為一個參數,然後在形狀為 [n,m]的一般稀疏矩陣中交換資訊。如果設定為 None,則預設為方形矩陣。
實作GCN layer
GCN層的數學定義為:
x i ( k ) = ∑ j ∈ N ( i ) ∪ i 1 d e g ( i ) ⋅ d e g ( j ) ⋅ ( Θ T ⋅ x j ( k − 1 ) ) x_i^{(k)}=\sum_{j\in\mathcal{N}(i)\cup{i}}\frac{1}{\sqrt{deg(i)}\cdot\sqrt{deg(j)}}\cdot(\Theta^T\cdot x_j^{(k-1)}) xi(k)=∑j∈N(i)∪ideg(i)
⋅deg(j)
1⋅(ΘT⋅xj(k−1)),
其中,相鄰節點的特征首先經過權重矩陣 Θ \Theta Θ 進行變換,按節點的度進行歸一化,最後進行求和。可以分為5個步驟:
- 為鄰接矩陣加自環
- 對節點特征矩陣執行線性變換
- 計算歸一化系數
- 歸一化 ϕ \phi ϕ中的節點特征
- 求和
步驟1-3通常在消息傳遞前處理的,4-5可以用
MessagePassing
基類進行處理,實作如下:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j