關於整圖分類,有篇知乎文章寫得很好:《【圖分類】10分鐘就學會的圖分類教程,基於pytorch和dgl》。下面的代碼也是來自這篇知乎文章。
import dgl
import torch
from torch._C import device
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import MiniGCDataset
from dgl.nn.pytorch import GraphConv
from sklearn.metrics import accuracy_score
class Classifier(nn.Module):
    """Two-layer graph convolutional network for whole-graph classification.

    Node in-degrees are used as the only input feature. Two ``GraphConv``
    layers with ReLU produce node embeddings, the embeddings are averaged per
    graph, and a linear layer maps every graph embedding to class scores.

    Args:
        in_dim: Size of the input node feature. ``forward`` feeds one degree
            value per node, so callers are expected to pass 1.
        hidden_dim: Width of both graph-convolution layers.
        n_classes: Number of output classes.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)  # first graph convolution
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  # second graph convolution
        self.classify = nn.Linear(hidden_dim, n_classes)  # graph-level classifier

    def forward(self, g):
        """Return the class scores for every graph contained in ``g``.

        Args:
            g: Batched graph. Below, N is the total number of nodes in the
                batch and n is the number of individual graphs.

        Returns:
            Tensor of shape [n, n_classes] holding unnormalised class scores.
        """
        # local_scope() discards the temporary 'h' node feature on exit, so
        # the caller's graph is not left with an extra feature after the call.
        with g.local_scope():
            # For simplicity the node degree is the initial node feature.
            # For an undirected graph, in-degree == out-degree.
            h = g.in_degrees().view(-1, 1).float()  # [N, 1]
            # Graph convolutions followed by the activation function.
            h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
            h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
            g.ndata['h'] = h  # attach the embeddings to the nodes
            # Mean-pool the node embeddings of each graph into a graph embedding.
            hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
            return self.classify(hg)  # [n, n_classes]
def collate(samples):
    """Merge a list of (graph, label) pairs into one batch.

    Args:
        samples: List of pairs such as
            ``[(graph1, label1), (graph2, label2), ...]``.

    Returns:
        A tuple ``(batched_graph, labels)``: the graphs combined by
        ``dgl.batch`` and the labels as a ``torch.long`` tensor.
    """
    # zip(*samples) transposes the pairs into (graph1, graph2, ...) and
    # (label1, label2, ...); each column is then turned into a list.
    graph_list, label_list = (list(column) for column in zip(*samples))
    # dgl.batch treats the graphs as disconnected components of one large graph.
    batched_graph = dgl.batch(graph_list)
    label_tensor = torch.tensor(label_list, dtype=torch.long)
    return batched_graph, label_tensor
# Build the training and test sets.
# 2000 training graphs; every graph has at least 10 and at most 20 nodes.
# NOTE(review): both datasets are built with the same generator arguments;
# confirm they do not yield overlapping graphs (e.g. via a shared default seed).
trainset = MiniGCDataset(2000, 10, 20)
testset = MiniGCDataset(1000, 10, 20)
# Batch the training graphs with PyTorch's DataLoader and the collate
# function defined earlier in this file.
data_loader = DataLoader(trainset, batch_size=64, shuffle=True,
                         collate_fn=collate)
# Prefer the GPU with index 2, but fall back to the first GPU or to the CPU
# so the script still runs on machines with fewer than three GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 2:
    DEVICE = torch.device("cuda:2")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
# Build the model: one input feature per node, 256 hidden units.
model = Classifier(1, 256, trainset.num_classes)
model.to(DEVICE)
# Categorical cross-entropy loss.
loss_func = nn.CrossEntropyLoss()
# Adam optimizer.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train the model for 100 epochs, recording the mean batch loss per epoch.
model.train()
epoch_losses = []
for epoch in range(100):
    epoch_loss = 0.0
    num_batches = 0  # counted explicitly so the built-in iter() is not shadowed
    for batchg, label in data_loader:
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
        prediction = model(batchg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Average over the batches seen; the guard avoids dividing by zero if
    # the loader yields nothing.
    epoch_loss /= max(num_batches, 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
# Evaluate the trained model on the test set.
test_loader = DataLoader(testset, collate_fn=collate, batch_size=64,
                         shuffle=False)
model.eval()
test_pred, test_label = [], []
with torch.no_grad():
    for graphs, targets in test_loader:
        graphs = graphs.to(DEVICE)
        targets = targets.to(DEVICE)
        # Class probabilities per graph, then the index of the largest one.
        class_probs = torch.softmax(model(graphs), 1)
        predicted = class_probs.max(1).indices.view(-1)
        # Collect predictions and ground truth as plain Python lists.
        test_pred.extend(predicted.detach().cpu().numpy().tolist())
        test_label.extend(targets.cpu().numpy().tolist())
test_accuracy = accuracy_score(test_label, test_pred)
print("Test accuracy: ", test_accuracy)
執行結果:
