對于圖神經網絡來說,最常見和被廣泛使用的任務之一就是節點分類。
圖資料中的訓練、驗證和測試集中的每個節點都具有從一組預定義的類别中配置設定的一個類别,即正确的标注。
節點回歸任務也類似,訓練、驗證和測試集中的每個節點都被标注了一個正确的數字。
概述
為了對節點進行分類,圖神經網絡執行了
guide_cn-message-passing
中介紹的消息傳遞機制,利用節點自身的特征和其鄰節點及邊的特征來計算節點的隐藏表示。
消息傳遞可以重複多輪,以利用更大範圍的鄰居資訊。
編寫神經網絡模型
DGL提供了一些内置的圖卷積子產品,可以完成一輪消息傳遞計算。
本章中選擇 :class:
dgl.nn.pytorch.SAGEConv
作為示範的樣例代碼(針對MXNet和PyTorch後端也有對應的子產品),
它是GraphSAGE模型中使用的圖卷積子產品。
對于圖上的深度學習模型,通常需要一個多層的圖神經網絡,并在這個網絡中要進行多輪的資訊傳遞。
可以通過堆疊圖卷積子產品來實作這種網絡架構,具體如下所示。
# 建構一個2層的GNN模型
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats):
super().__init__()
# 執行個體化SAGEConve,in_feats是輸入特征的次元,out_feats是輸出特征的次元,aggregator_type是聚合函數的類型
self.conv1 = dglnn.SAGEConv(
in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
self.conv2 = dglnn.SAGEConv(
in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')
def forward(self, graph, inputs):
# 輸入是節點的特征
h = self.conv1(graph, inputs)
h = F.relu(h)
h = self.conv2(graph, h)
return h
模型的訓練
全圖(使用所有的節點和邊的特征)上的訓練隻需要使用上面定義的模型進行前向傳播計算,并通過在訓練節點上比較預測和真實标簽來計算損失,進而完成後向傳播。
本節使用DGL内置的資料集 :class:
dgl.data.CiteseerGraphDataset
來展示模型的訓練。
節點特征和标簽存儲在其圖上,訓練、驗證和測試的分割也以布爾掩碼的形式存儲在圖上。
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)
下面是通過使用準确性來評估模型的一個例子。
def evaluate(model, graph, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(graph, features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
使用者可以按如下方式實作模型的訓練。
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
model.train()
# 使用所有節點(全圖)進行前向傳播計算
logits = model(graph, node_features)
# 計算損失值
loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
# 計算驗證集的準确度
acc = evaluate(model, graph, node_features, node_labels, valid_mask)
# 進行反向傳播計算
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
# 如果需要的話,儲存訓練好的模型。本例中省略。
DGL的GraphSAGE樣例 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_full.py>
__
提供了一個端到端的同構圖節點分類的例子。使用者可以在
GraphSAGE
類中看到模型實作的細節。
這個模型具有可調節的層數、dropout機率,以及可定制的聚合函數和非線性函數。
異構圖上的節點分類模型的訓練
如果圖是異構的,使用者可能希望沿着所有邊類型從鄰居那裡收集消息。
使用者可以使用 :class:
dgl.nn.pytorch.HeteroGraphConv
子產品(針對MXNet和PyTorch後端也有對應的子產品)在所有邊類型上執行消息傳遞,
并為每種邊類型使用一種圖卷積子產品。
下面的代碼定義了一個異構圖卷積子產品。子產品首先對每種邊類型進行單獨的圖卷積計算,然後将每種邊類型上的消息聚合結果再相加,
并作為所有節點類型的最終結果。
# Define a Heterograph Conv model
class RGCN(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, rel_names):
super().__init__()
# 執行個體化HeteroGraphConv,in_feats是輸入特征的次元,out_feats是輸出特征的次元,aggregate是聚合函數的類型
self.conv1 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(in_feats, hid_feats)
for rel in rel_names}, aggregate='sum')
self.conv2 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats, out_feats)
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
# 輸入是節點的特征字典
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
return h
dgl.nn.HeteroGraphConv
接收一個節點類型和節點特征張量的字典作為輸入,并傳回另一個節點類型和節點特征的字典。
本章的
guide_cn-training-heterogeneous-graph-example
中已經有了
user
和
item
的特征,使用者可用如下代碼擷取。
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']
然後,使用者可以簡單地按如下形式進行前向傳播計算:
node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
h_user = h_dict['user']
h_item = h_dict['item']
異構圖上模型的訓練和同構圖的模型訓練是一樣的,隻是這裡使用了一個包括節點表示的字典來計算預測值。
例如,如果隻預測
user
節點的類别,使用者可以從傳回的字典中提取
user
的節點嵌入。
opt = torch.optim.Adam(model.parameters())
for epoch in range(5):
model.train()
# 使用所有節點的特征進行前向傳播計算,并提取輸出的user節點嵌入
logits = model(hetero_graph, node_features)['user']
# 計算損失值
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
# 計算驗證集的準确度。在本例中省略。
# 進行反向傳播計算
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())
# 如果需要的話,儲存訓練好的模型。本例中省略。
完整例子大家可以參考
DGL提供了一個用于節點分類的RGCN的端到端的例子
RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify.py>
__
。使用者可以在
RGCN模型實作檔案 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py>
__
中檢視異構圖卷積
RelGraphConvLayer
的具體定義。