天天看點

6-1-資料完整存于記憶體的資料集類

資料完全存于記憶體的資料集類

在上一節内容中,我們學習了基于圖神經網絡的節點表征學習方法,并用了現成的很小的資料集實作了節點分類任務。在此第6節的上半部分,我們将學習在PyG中如何自定義一個資料完全存于記憶體的資料集類。

在PyG中,我們通過繼承​​InMemoryDataset​​類來自定義一個資料可全部存儲到記憶體的資料集類。

​<code>​InMemoryDataset​</code>​官方文檔:​​torch_geometric.data.InMemoryDataset​​

如上方的​​InMemoryDataset​​類的構造函數接口所示,每個資料集都要有一個根檔案夾(​<code>​root​</code>​),它訓示資料集應該被儲存在哪裡。在根目錄下至少有兩個檔案夾:

一個檔案夾為​<code>​raw_dir​</code>​,它用于存儲未處理的檔案,從網絡上下載下傳的資料集檔案會被存放到這裡;

另一個檔案夾為​<code>​processed_dir​</code>​,處理後的資料集被儲存到這裡。

此外,繼承​​InMemoryDataset​​類的每個資料集類可以傳遞一個​<code>​transform​</code>​函數,一個​<code>​pre_transform​</code>​函數和一個​<code>​pre_filter​</code>​函數,它們預設都為​<code>​None​</code>​。

​<code>​transform​</code>​函數接受​<code>​Data​</code>​對象為參數,對其轉換後傳回。此函數在每一次資料通路時被調用,是以它應該用于資料增廣(Data Augmentation)。

​<code>​pre_transform​</code>​函數接受 ​​Data​​對象為參數,對其轉換後傳回。此函數在樣本 ​​Data​​對象儲存到檔案前調用,是以它最好用于隻需要做一次的大量預計算。

​<code>​pre_filter​</code>​函數可以在儲存前手動過濾掉資料對象。該函數的一個用例是,過濾樣本類别。

為了建立一個​​InMemoryDataset​​,我們需要實作四個基本方法:

​​raw_file_names()​​這是一個屬性方法,傳回一個檔案名清單,檔案應該能在​<code>​raw_dir​</code>​檔案夾中找到,否則調用​<code>​process()​</code>​函數下載下傳檔案到​<code>​raw_dir​</code>​檔案夾。

​​processed_file_names()​​。這是一個屬性方法,傳回一個檔案名清單,檔案應該能在​<code>​processed_dir​</code>​檔案夾中找到,否則調用​<code>​process()​</code>​函數對樣本做預處理然後儲存到​<code>​processed_dir​</code>​檔案夾。

​​download()​​: 将原始資料檔案下載下傳到​<code>​raw_dir​</code>​檔案夾。

​​process()​​: 對樣本做預處理然後儲存到​<code>​processed_dir​</code>​檔案夾。

樣本從原始檔案轉換成 ​​Data​​類對象的過程定義在​<code>​process​</code>​函數中。在該函數中,有時我們需要讀取和建立一個 ​​Data​​對象的清單,并将其儲存到​<code>​processed_dir​</code>​中。由于python儲存一個巨大的清單是相當慢的,是以我們在儲存之前通過​​collate()​​函數将該清單集合成一個巨大的 ​​Data​​對象。該函數還會傳回一個切片字典,以便從這個對象中重構單個樣本。最後,我們需要在構造函數中把這​<code>​Data​</code>​對象和切片字典分别加載到屬性​<code>​self.data​</code>​和​<code>​self.slices​</code>​中。我們通過下面的例子來介紹生成一個​​​​InMemoryDataset​​​​子類對象時程式的運作流程。

由于我們手頭沒有實際應用中的資料集,是以我們以公開資料集​<code>​PubMed​</code>​為例子。​<code>​PubMed ​</code>​資料集存儲的是文章引用網絡,文章對應圖的結點,如果兩篇文章存在引用關系(無論引用與被引),則這兩篇文章對應的結點之間存在邊。該資料集來源于論文Revisiting Semi-Supervised Learning with Graph Embeddings。我們直接基于PyG中的​<code>​Planetoid​</code>​類修改得到下面的​<code>​PlanetoidPubMed​</code>​資料集類。

在我們生成一個​<code>​PlanetoidPubMed​</code>​類的對象時,程式運作流程如下:

首先檢查資料原始檔案是否已下載下傳:

檢查​<code>​self.raw_dir​</code>​目錄下是否存在​<code>​raw_file_names()​</code>​屬性方法傳回的每個檔案,

如有檔案不存在,則調用​<code>​download()​</code>​方法執行原始檔案下載下傳。

其中​<code>​self.raw_dir​</code>​為​<code>​osp.join(self.root, 'raw')​</code>​。

其次檢查資料是否經過處理:

首先檢查之前對資料做變換的方法:檢查​<code>​self.processed_dir​</code>​目錄下是否存在​<code>​pre_transform.pt​</code>​檔案:如果存在,意味着之前進行過資料變換,則需加載該檔案擷取之前所用的資料變換的方法,并檢查它與目前​<code>​pre_transform​</code>​參數指定的方法是否相同;如果不相同則會報出一個警告,“The pre_transform argument differs from the one used in ……”。

接着檢查之前的樣本過濾的方法:檢查​<code>​self.processed_dir​</code>​目錄下是否存在​<code>​pre_filter.pt​</code>​檔案,如果存在,意味着之前進行過樣本過濾,則需加載該檔案擷取之前所用的樣本過濾的方法,并檢查它與目前​<code>​pre_filter​</code>​參數指定的方法是否相同,如果不相同則會報出一個警告,“The pre_filter argument differs from the one used in ……”。其中​<code>​self.processed_dir​</code>​為​<code>​osp.join(self.root, 'processed')​</code>​。

接着檢查是否存在處理好的資料:檢查​<code>​self.processed_dir​</code>​目錄下是否存在​<code>​self.processed_paths​</code>​方法傳回的所有檔案,如有檔案不存在,意味着不存在已經處理好的樣本的檔案,如需執行以下的操作:

調用​<code>​process​</code>​方法,進行資料處理。

如果​<code>​pre_transform​</code>​參數不為​<code>​None​</code>​,則調用​<code>​pre_transform​</code>​方法進行資料處理。

如果​<code>​pre_filter​</code>​參數不為​<code>​None​</code>​,則進行樣本過濾(此例子中不需要進行樣本過濾,​<code>​pre_filter​</code>​參數始終為​<code>​None​</code>​)。

儲存處理好的資料到檔案,檔案存儲在​<code>​processed_paths()​</code>​屬性方法傳回的路徑。如果将資料儲存到多個檔案中,則傳回的路徑有多個。這些路徑都在​<code>​self.processed_dir​</code>​目錄下,以​<code>​processed_file_names()​</code>​屬性方法的傳回值為檔案名。

最後儲存新的​<code>​pre_transform.pt​</code>​檔案和​<code>​pre_filter.pt​</code>​檔案,其中分别存儲目前使用的資料處理方法和樣本過濾方法。

現在讓我們檢視這個資料集:

可以看到這個資料集包含三個分類任務,共19,717個結點,88,648條邊,節點特征次元為500。