資料完全存于記憶體的資料集類
在上一節内容中,我們學習了基于圖神經網絡的節點表征學習方法,并用了現成的很小的資料集實作了節點分類任務。在此第6節的上半部分,我們将學習在PyG中如何自定義一個資料完全存于記憶體的資料集類。
在PyG中,我們通過繼承InMemoryDataset類來自定義一個資料可全部存儲到記憶體的資料集類。
<code>InMemoryDataset</code>官方文檔:torch_geometric.data.InMemoryDataset
如上方的InMemoryDataset類的構造函數接口所示,每個資料集都要有一個根檔案夾(<code>root</code>),它訓示資料集應該被儲存在哪裡。在根目錄下至少有兩個檔案夾:
一個檔案夾為<code>raw_dir</code>,它用于存儲未處理的檔案,從網絡上下載下傳的資料集檔案會被存放到這裡;
另一個檔案夾為<code>processed_dir</code>,處理後的資料集被儲存到這裡。
此外,繼承InMemoryDataset類的每個資料集類可以傳遞一個<code>transform</code>函數,一個<code>pre_transform</code>函數和一個<code>pre_filter</code>函數,它們預設都為<code>None</code>。
<code>transform</code>函數接受<code>Data</code>對象為參數,對其轉換後傳回。此函數在每一次資料通路時被調用,是以它應該用于資料增廣(Data Augmentation)。
<code>pre_transform</code>函數接受 Data對象為參數,對其轉換後傳回。此函數在樣本 Data對象儲存到檔案前調用,是以它最好用于隻需要做一次的大量預計算。
<code>pre_filter</code>函數可以在儲存前手動過濾掉資料對象。該函數的一個用例是,過濾樣本類别。
為了建立一個InMemoryDataset,我們需要實作四個基本方法:
raw_file_names()這是一個屬性方法,傳回一個檔案名清單,檔案應該能在<code>raw_dir</code>檔案夾中找到,否則調用<code>process()</code>函數下載下傳檔案到<code>raw_dir</code>檔案夾。
processed_file_names()。這是一個屬性方法,傳回一個檔案名清單,檔案應該能在<code>processed_dir</code>檔案夾中找到,否則調用<code>process()</code>函數對樣本做預處理然後儲存到<code>processed_dir</code>檔案夾。
download(): 将原始資料檔案下載下傳到<code>raw_dir</code>檔案夾。
process(): 對樣本做預處理然後儲存到<code>processed_dir</code>檔案夾。
樣本從原始檔案轉換成 Data類對象的過程定義在<code>process</code>函數中。在該函數中,有時我們需要讀取和建立一個 Data對象的清單,并将其儲存到<code>processed_dir</code>中。由于python儲存一個巨大的清單是相當慢的,是以我們在儲存之前通過collate()函數将該清單集合成一個巨大的 Data對象。該函數還會傳回一個切片字典,以便從這個對象中重構單個樣本。最後,我們需要在構造函數中把這<code>Data</code>對象和切片字典分别加載到屬性<code>self.data</code>和<code>self.slices</code>中。我們通過下面的例子來介紹生成一個InMemoryDataset子類對象時程式的運作流程。
由于我們手頭沒有實際應用中的資料集,是以我們以公開資料集<code>PubMed</code>為例子。<code>PubMed </code>資料集存儲的是文章引用網絡,文章對應圖的結點,如果兩篇文章存在引用關系(無論引用與被引),則這兩篇文章對應的結點之間存在邊。該資料集來源于論文Revisiting Semi-Supervised Learning with Graph Embeddings。我們直接基于PyG中的<code>Planetoid</code>類修改得到下面的<code>PlanetoidPubMed</code>資料集類。
在我們生成一個<code>PlanetoidPubMed</code>類的對象時,程式運作流程如下:
首先檢查資料原始檔案是否已下載下傳:
檢查<code>self.raw_dir</code>目錄下是否存在<code>raw_file_names()</code>屬性方法傳回的每個檔案,
如有檔案不存在,則調用<code>download()</code>方法執行原始檔案下載下傳。
其中<code>self.raw_dir</code>為<code>osp.join(self.root, 'raw')</code>。
其次檢查資料是否經過處理:
首先檢查之前對資料做變換的方法:檢查<code>self.processed_dir</code>目錄下是否存在<code>pre_transform.pt</code>檔案:如果存在,意味着之前進行過資料變換,則需加載該檔案擷取之前所用的資料變換的方法,并檢查它與目前<code>pre_transform</code>參數指定的方法是否相同;如果不相同則會報出一個警告,“The pre_transform argument differs from the one used in ……”。
接着檢查之前的樣本過濾的方法:檢查<code>self.processed_dir</code>目錄下是否存在<code>pre_filter.pt</code>檔案,如果存在,意味着之前進行過樣本過濾,則需加載該檔案擷取之前所用的樣本過濾的方法,并檢查它與目前<code>pre_filter</code>參數指定的方法是否相同,如果不相同則會報出一個警告,“The pre_filter argument differs from the one used in ……”。其中<code>self.processed_dir</code>為<code>osp.join(self.root, 'processed')</code>。
接着檢查是否存在處理好的資料:檢查<code>self.processed_dir</code>目錄下是否存在<code>self.processed_paths</code>方法傳回的所有檔案,如有檔案不存在,意味着不存在已經處理好的樣本的檔案,如需執行以下的操作:
調用<code>process</code>方法,進行資料處理。
如果<code>pre_transform</code>參數不為<code>None</code>,則調用<code>pre_transform</code>方法進行資料處理。
如果<code>pre_filter</code>參數不為<code>None</code>,則進行樣本過濾(此例子中不需要進行樣本過濾,<code>pre_filter</code>參數始終為<code>None</code>)。
儲存處理好的資料到檔案,檔案存儲在<code>processed_paths()</code>屬性方法傳回的路徑。如果将資料儲存到多個檔案中,則傳回的路徑有多個。這些路徑都在<code>self.processed_dir</code>目錄下,以<code>processed_file_names()</code>屬性方法的傳回值為檔案名。
最後儲存新的<code>pre_transform.pt</code>檔案和<code>pre_filter.pt</code>檔案,其中分别存儲目前使用的資料處理方法和樣本過濾方法。
現在讓我們檢視這個資料集:
可以看到這個資料集包含三個分類任務,共19,717個結點,88,648條邊,節點特征次元為500。