天天看点

6-1-数据完整存于内存的数据集类

数据完全存于内存的数据集类

在上一节内容中,我们学习了基于图神经网络的节点表征学习方法,并用了现成的很小的数据集实现了节点分类任务。在此第6节的上半部分,我们将学习在PyG中如何自定义一个数据完全存于内存的数据集类。

在PyG中,我们通过继承​​InMemoryDataset​​类来自定义一个数据可全部存储到内存的数据集类。

​<code>​InMemoryDataset​</code>​官方文档:​​torch_geometric.data.InMemoryDataset​​

如上方的​​InMemoryDataset​​类的构造函数接口所示,每个数据集都要有一个根文件夹(​<code>​root​</code>​),它指示数据集应该被保存在哪里。在根目录下至少有两个文件夹:

一个文件夹为​<code>​raw_dir​</code>​,它用于存储未处理的文件,从网络上下载的数据集文件会被存放到这里;

另一个文件夹为​<code>​processed_dir​</code>​,处理后的数据集被保存到这里。

此外,继承​​InMemoryDataset​​类的每个数据集类可以传递一个​<code>​transform​</code>​函数,一个​<code>​pre_transform​</code>​函数和一个​<code>​pre_filter​</code>​函数,它们默认都为​<code>​None​</code>​。

​<code>​transform​</code>​函数接受​<code>​Data​</code>​对象为参数,对其转换后返回。此函数在每一次数据访问时被调用,所以它应该用于数据增广(Data Augmentation)。

​<code>​pre_transform​</code>​函数接受 ​​Data​​对象为参数,对其转换后返回。此函数在样本 ​​Data​​对象保存到文件前调用,所以它最好用于只需要做一次的大量预计算。

​<code>​pre_filter​</code>​函数可以在保存前手动过滤掉数据对象。该函数的一个用例是,过滤样本类别。

为了创建一个​​InMemoryDataset​​,我们需要实现四个基本方法:

​​raw_file_names()​​这是一个属性方法,返回一个文件名列表,文件应该能在​<code>​raw_dir​</code>​文件夹中找到,否则调用​<code>​process()​</code>​函数下载文件到​<code>​raw_dir​</code>​文件夹。

​​processed_file_names()​​。这是一个属性方法,返回一个文件名列表,文件应该能在​<code>​processed_dir​</code>​文件夹中找到,否则调用​<code>​process()​</code>​函数对样本做预处理然后保存到​<code>​processed_dir​</code>​文件夹。

​​download()​​: 将原始数据文件下载到​<code>​raw_dir​</code>​文件夹。

​​process()​​: 对样本做预处理然后保存到​<code>​processed_dir​</code>​文件夹。

样本从原始文件转换成 ​​Data​​类对象的过程定义在​<code>​process​</code>​函数中。在该函数中,有时我们需要读取和创建一个 ​​Data​​对象的列表,并将其保存到​<code>​processed_dir​</code>​中。由于python保存一个巨大的列表是相当慢的,因此我们在保存之前通过​​collate()​​函数将该列表集合成一个巨大的 ​​Data​​对象。该函数还会返回一个切片字典,以便从这个对象中重构单个样本。最后,我们需要在构造函数中把这​<code>​Data​</code>​对象和切片字典分别加载到属性​<code>​self.data​</code>​和​<code>​self.slices​</code>​中。我们通过下面的例子来介绍生成一个​​​​InMemoryDataset​​​​子类对象时程序的运行流程。

由于我们手头没有实际应用中的数据集,因此我们以公开数据集​<code>​PubMed​</code>​为例子。​<code>​PubMed ​</code>​数据集存储的是文章引用网络,文章对应图的结点,如果两篇文章存在引用关系(无论引用与被引),则这两篇文章对应的结点之间存在边。该数据集来源于论文Revisiting Semi-Supervised Learning with Graph Embeddings。我们直接基于PyG中的​<code>​Planetoid​</code>​类修改得到下面的​<code>​PlanetoidPubMed​</code>​数据集类。

在我们生成一个​<code>​PlanetoidPubMed​</code>​类的对象时,程序运行流程如下:

首先检查数据原始文件是否已下载:

检查​<code>​self.raw_dir​</code>​目录下是否存在​<code>​raw_file_names()​</code>​属性方法返回的每个文件,

如有文件不存在,则调用​<code>​download()​</code>​方法执行原始文件下载。

其中​<code>​self.raw_dir​</code>​为​<code>​osp.join(self.root, 'raw')​</code>​。

其次检查数据是否经过处理:

首先检查之前对数据做变换的方法:检查​<code>​self.processed_dir​</code>​目录下是否存在​<code>​pre_transform.pt​</code>​文件:如果存在,意味着之前进行过数据变换,则需加载该文件获取之前所用的数据变换的方法,并检查它与当前​<code>​pre_transform​</code>​参数指定的方法是否相同;如果不相同则会报出一个警告,“The pre_transform argument differs from the one used in ……”。

接着检查之前的样本过滤的方法:检查​<code>​self.processed_dir​</code>​目录下是否存在​<code>​pre_filter.pt​</code>​文件,如果存在,意味着之前进行过样本过滤,则需加载该文件获取之前所用的样本过滤的方法,并检查它与当前​<code>​pre_filter​</code>​参数指定的方法是否相同,如果不相同则会报出一个警告,“The pre_filter argument differs from the one used in ……”。其中​<code>​self.processed_dir​</code>​为​<code>​osp.join(self.root, 'processed')​</code>​。

接着检查是否存在处理好的数据:检查​<code>​self.processed_dir​</code>​目录下是否存在​<code>​self.processed_paths​</code>​方法返回的所有文件,如有文件不存在,意味着不存在已经处理好的样本的文件,如需执行以下的操作:

调用​<code>​process​</code>​方法,进行数据处理。

如果​<code>​pre_transform​</code>​参数不为​<code>​None​</code>​,则调用​<code>​pre_transform​</code>​方法进行数据处理。

如果​<code>​pre_filter​</code>​参数不为​<code>​None​</code>​,则进行样本过滤(此例子中不需要进行样本过滤,​<code>​pre_filter​</code>​参数始终为​<code>​None​</code>​)。

保存处理好的数据到文件,文件存储在​<code>​processed_paths()​</code>​属性方法返回的路径。如果将数据保存到多个文件中,则返回的路径有多个。这些路径都在​<code>​self.processed_dir​</code>​目录下,以​<code>​processed_file_names()​</code>​属性方法的返回值为文件名。

最后保存新的​<code>​pre_transform.pt​</code>​文件和​<code>​pre_filter.pt​</code>​文件,其中分别存储当前使用的数据处理方法和样本过滤方法。

现在让我们查看这个数据集:

可以看到这个数据集包含三个分类任务,共19,717个结点,88,648条边,节点特征维度为500。