一、mini-batch
在graph classification的一些基準資料集中,每個圖的樣本都很小,如果每次隻操作一個,不能充分利用GPU資源。是以考慮把它們分成多個mini-batch。
1、原理
mini-batch就是并行處理多個圖,這裡把多個圖的鄰接矩陣A1、A2、……拼接成一個大的矩陣,可以看作一個對角矩陣(出現了很多0元素,即稀疏矩陣的存儲)

在imgae、language領域中的mini-batch有兩種方法:rescaling、padding,把每個樣本都處理成一樣的size、一樣的shape。
但是這兩種方法都不适用于graph,會造成很多不必要的記憶體浪費。
2、代碼
PyG架構中的dataloader事先封裝好了
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
for step, data in enumerate(train_loader):
print(f'Step {step + 1}:')
print('=======')
print(f'Number of graphs in the current batch: {data.num_graphs}')
print(data)
print()
輸出結果:
Step 1:
=======
Number of graphs in the current batch: 64
Batch(edge_attr=[2560, 4], edge_index=[2, 2560], x=[1154, 7], y=[64], batch=[1154], ptr=[65])
Step 2:
======= Number of graphs in the current batch: 64 Batch(edge_attr=[2454, 4], edge_index=[2, 2454], x=[1121, 7], y=[64],
batch=[1121], ptr=[65])
Step 3:
======= Number of graphs in the current batch: 22 Batch(edge_attr=[980, 4], edge_index=[2, 980], x=[439, 7], y=[22],
batch=[439], ptr=[23])