1 The constructor of a DGL NN module
The constructor performs the following tasks:
- Set options.
- Register learnable parameters or submodules.
- Initialize parameters.
import torch.nn as nn
from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
In the constructor, the user first needs to set the data dimensions. For a general PyTorch module, the dimensions usually include the input dimension, the output dimension, and the hidden dimension. For a graph neural network, the input dimension can be split into the source node feature dimension and the destination node feature dimension.
Besides the data dimensions, a typical option for a graph neural network is the aggregation type (self._aggre_type). For a given destination node, the aggregation type determines how the information carried on its different incoming edges is aggregated. Commonly used aggregation types include mean, sum, max, and min. Some modules may use more complicated aggregation functions such as lstm.
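To make the aggregation types concrete, here is a small, self-contained sketch (an illustrative example, not part of SAGEConv; the toy graph and feature values are made up) that uses DGL's built-in message and reduce functions to compute a mean aggregation over incoming edges:

import dgl
import torch
import dgl.function as fn

# Toy graph with edges 0->2, 1->2, 2->0
g = dgl.graph(([0, 1, 2], [2, 2, 0]))
g.ndata['h'] = torch.tensor([[1.], [3.], [5.]])
# Copy each source feature onto its edge as message 'm', then average per destination node
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
print(g.ndata['neigh'])  # node 2 gets the mean of nodes 0 and 1 -> 2.0; node 1 has no in-edges -> 0.0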
The norm argument in the SAGEConv constructor is a callable used for feature normalization. In the GraphSAGE paper, the normalization can be the L2 norm:

$h_v = h_v / \lVert h_v \rVert_2$
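For example, a minimal sketch of an L2-normalization callable that could be passed as norm (the helper name l2_norm is only illustrative):

import torch.nn.functional as F

def l2_norm(h):
    # h_v <- h_v / ||h_v||_2, applied row-wise
    return F.normalize(h, p=2, dim=-1)

Continuing with the constructor, the remaining code registers the submodules that depend on the aggregation type: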
# Aggregation types: mean, max_pool, lstm, gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
    raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'max_pool':
    self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
    self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
    self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
The code above registers parameters and submodules. In SAGEConv, the submodules differ depending on the aggregation type. These are plain PyTorch NN modules such as nn.Linear and nn.LSTM. At the end, the constructor calls reset_parameters() to initialize the weights.
def reset_parameters(self):
    """Reinitialize learnable parameters."""
    gain = nn.init.calculate_gain('relu')
    if self._aggre_type == 'max_pool':
        nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
    if self._aggre_type == 'lstm':
        self.lstm.reset_parameters()
    if self._aggre_type != 'gcn':
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
    nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
2 Writing the forward function of a DGL NN module
In an NN module, the forward() function does the actual message passing and computation. Compared with PyTorch NN modules, which usually take tensors as arguments, a DGL NN module takes one additional argument: a dgl.DGLGraph. The work of the forward() function can typically be split into three parts:
- Check that the input graph object is valid.
- Perform message passing and aggregation.
- Update the aggregated features as the output.
The rest of this section walks through the forward() function of the SAGEConv example.
Checking the validity of the input graph object
def forward(self, graph, feat):
    with graph.local_scope():
        # Determine the graph type, then expand the input features accordingly
        feat_src, feat_dst = expand_as_pair(feat, graph)
The forward() function needs to handle many corner cases in the input that can lead to invalid values during computation and message passing. One typical check in conv modules such as dgl.nn.pytorch.conv.GraphConv is to verify that the input graph has no 0-in-degree nodes. When a node has an in-degree of 0, its mailbox is empty and the output of the aggregation function is all zeros, which can hurt model performance. In the dgl.nn.pytorch.conv.SAGEConv module, however, the aggregated features are concatenated with the node's own original features, so the output of forward() will not be all zeros and no such check is needed in this case.
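As a hedged sketch of what such a check can look like in a module that does need it (modeled on the kind of validation GraphConv performs; the error message is illustrative):

from dgl import DGLError

if (graph.in_degrees() == 0).any():
    raise DGLError('There are 0-in-degree nodes in the graph; '
                   'the output for those nodes will be invalid.')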
A DGL NN module should be reusable across different graph input types, including homogeneous graphs, heterogeneous graphs (see the guide chapter on heterogeneous graphs), and subgraph blocks (see the guide chapter on mini-batch training).
The mathematical formulas of SAGEConv are:

$h_{\mathcal{N}(v)}^{(l+1)} = \mathrm{aggregate}\left(\{h_u^{(l)}, \forall u \in \mathcal{N}(v)\}\right)$
$h_v^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}(h_v^{(l)}, h_{\mathcal{N}(v)}^{(l+1)})\right)$
$h_v^{(l+1)} = \mathrm{norm}(h_v^{(l+1)})$
The source node features feat_src and the destination node features feat_dst need to be specified according to the graph type. The function that determines the graph type and expands feat into feat_src and feat_dst is expand_as_pair(). Its details are shown below.
# In DGL's source, `Mapping` is collections.abc.Mapping and `F` is DGL's backend module.
def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph case
        return input_
    elif g is not None and g.is_block:
        # Subgraph block case
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homogeneous graph case
        return input_, input_
For full-graph training on a homogeneous graph, the source and destination nodes are the same: they are all the nodes in the graph.

For a heterogeneous graph, the graph can be split into several bipartite graphs, one for each relation. A relation is represented as (src_type, edge_type, dst_type). When the input feature feat is a tuple, the graph is treated as a bipartite graph: the first element of the tuple is the source node feature and the second element is the destination node feature.

In mini-batch training, the computation is applied on a subgraph sampled for a given batch of destination nodes. Such a subgraph is called a block in DGL. In the block-creation stage, dst nodes are placed at the front of the node list, so feat_dst can be found by the index [0:g.number_of_dst_nodes()].

Once feat_src and feat_dst are determined, the computation for the three graph types above is identical.
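A small usage sketch of expand_as_pair for the homogeneous and bipartite cases (the graph and feature shapes are made up for illustration):

import torch
import dgl
from dgl.utils import expand_as_pair

g = dgl.graph(([0, 1, 2], [1, 2, 3]))           # homogeneous graph with 4 nodes
feat = torch.randn(4, 5)
feat_src, feat_dst = expand_as_pair(feat, g)     # the same tensor is returned twice

# Passing a tuple means the input is treated as bipartite:
# 3 source nodes and 2 destination nodes, each with 5-dimensional features
feat_src, feat_dst = expand_as_pair((torch.randn(3, 5), torch.randn(2, 5)))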
Message passing and aggregation
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

# Keep the destination nodes' own features; they are combined with the
# aggregated neighbor features in the non-gcn branches below.
h_self = feat_dst
if self._aggre_type == 'mean':
    graph.srcdata['h'] = feat_src
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
    # Divide by the in-degree (plus one for the node's own feature)
    degs = graph.in_degrees().to(feat_dst)
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'max_pool':
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
else:
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# The gcn aggregation in GraphSAGE does not need fc_self
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
The code above does the message passing and aggregation computation. This part of the code varies from module to module.
Updating the features after aggregation as the output
# Activation function
if self.activation is not None:
    rst = self.activation(rst)
# Normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst
The last part of the forward() function updates the node features after message aggregation. A common update operation is to apply the activation function and normalization according to the options set in the constructor.
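Putting the pieces together, here is a minimal usage sketch of the SAGEConv module defined above on a small homogeneous graph (the graph, feature sizes, and torch.relu activation are chosen only for illustration):

import torch
import dgl

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))   # a 4-node directed cycle
feat = torch.randn(4, 10)
conv = SAGEConv(in_feats=10, out_feats=16, aggregator_type='mean',
                activation=torch.relu)
out = conv(g, feat)   # shape: (4, 16)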
3 A simple graph classification task
In this tutorial, we will learn how to perform graph classification with DGL. The task is to classify graphs drawn from eight different topology types. Here we directly use DGL's synthetic dataset data.MiniGCDataset. The dataset contains eight different classes of graphs, and each class has the same number of graph samples.
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, each graph is
# of size [10, 20]
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()
Creating batched graph data
import dgl
import torch

def collate(samples):
    # The input `samples` is a list of pairs (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels, dtype=torch.long)
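dgl.batch merges a list of graphs into one big graph whose components stay disconnected, which is what lets a single forward pass handle a whole mini-batch. A quick sketch (the two toy graphs are made up):

g1 = dgl.graph(([0, 1], [1, 2]))   # 3 nodes, 2 edges
g2 = dgl.graph(([0], [1]))         # 2 nodes, 1 edge
bg = dgl.batch([g1, g2])
print(bg.batch_size)    # 2
print(bg.num_nodes())   # 5
print(bg.num_edges())   # 3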
Building the graph classifier
from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use the node degree as the initial node feature. For undirected graphs,
        # the in-degree is the same as the out-degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and apply the activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate the graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
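As a quick sanity check (an assumed snippet, reusing the dataset created earlier), the classifier can be run on a small batch to confirm that the output shape is (batch_size, n_classes):

sample_graphs = [dataset[i][0] for i in range(4)]
bg = dgl.batch(sample_graphs)
clf = Classifier(in_dim=1, hidden_dim=16, n_classes=dataset.num_classes)
logits = clf(bg)
print(logits.shape)   # torch.Size([4, 8])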
import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)
# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
Epoch 0, loss 2.0010
Epoch 1, loss 1.9744
Epoch 2, loss 1.9551
Epoch 3, loss 1.9444
Epoch 4, loss 1.9318
Epoch 5, loss 1.9170
Epoch 6, loss 1.8928
Epoch 7, loss 1.8573
Epoch 8, loss 1.8212
Epoch 9, loss 1.7715
Epoch 10, loss 1.7152
Epoch 11, loss 1.6570
Epoch 12, loss 1.5885
Epoch 13, loss 1.5308
Epoch 14, loss 1.4719
Epoch 15, loss 1.4158
Epoch 16, loss 1.3515
Epoch 17, loss 1.2963
Epoch 18, loss 1.2417
Epoch 19, loss 1.1978
Epoch 20, loss 1.1698
Epoch 21, loss 1.1086
Epoch 22, loss 1.0780
Epoch 23, loss 1.0459
Epoch 24, loss 1.0192
Epoch 25, loss 1.0017
Epoch 26, loss 1.0297
Epoch 27, loss 0.9784
Epoch 28, loss 0.9486
Epoch 29, loss 0.9327
Epoch 30, loss 0.9133
Epoch 31, loss 0.9265
Epoch 32, loss 0.9177
Epoch 33, loss 0.9303
Epoch 34, loss 0.8666
Epoch 35, loss 0.8639
Epoch 36, loss 0.8474
Epoch 37, loss 0.8858
Epoch 38, loss 0.8393
Epoch 39, loss 0.8306
Epoch 40, loss 0.8204
Epoch 41, loss 0.8057
Epoch 42, loss 0.7998
Epoch 43, loss 0.7909
Epoch 44, loss 0.7840
Epoch 45, loss 0.7807
Epoch 46, loss 0.7882
Epoch 47, loss 0.7701
Epoch 48, loss 0.7612
Epoch 49, loss 0.7563
Epoch 50, loss 0.7430
Epoch 51, loss 0.7354
Epoch 52, loss 0.7357
Epoch 53, loss 0.7326
Epoch 54, loss 0.7249
Epoch 55, loss 0.7181
Epoch 56, loss 0.7146
Epoch 57, loss 0.7306
Epoch 58, loss 0.7143
Epoch 59, loss 0.7018
Epoch 60, loss 0.7130
Epoch 61, loss 0.7003
Epoch 62, loss 0.6977
Epoch 63, loss 0.7120
Epoch 64, loss 0.6979
Epoch 65, loss 0.7370
Epoch 66, loss 0.7223
Epoch 67, loss 0.6980
Epoch 68, loss 0.6891
Epoch 69, loss 0.6715
Epoch 70, loss 0.6736
Epoch 71, loss 0.6709
Epoch 72, loss 0.6583
Epoch 73, loss 0.6717
Epoch 74, loss 0.6683
Epoch 75, loss 0.6656
Epoch 76, loss 0.6477
Epoch 77, loss 0.6414
Epoch 78, loss 0.6442
Epoch 79, loss 0.6398
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
(test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
(test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))
Accuracy of sampled predictions on the test set: 58.7500%
Accuracy of argmax predictions on the test set: 62.500000%