天天看点

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

前言

近年来深度学习模型在视觉任务上取得了巨大的成功,但这种成功有一部分原因来自于庞大的标记数据以及大量的计算资源,这使得这些模型在处理几乎没有标记数据的新类时显得非常乏力。对于我们人类来说,在识别物体时,仅需少量的图像,或者甚至不需要图像而仅仅根据对物体的描述,就能根据以往的知识来识别物体。这是由于我们人类有先验知识,我们会利用自己的先验知识进行学习。如何让模型能够实现这种快速学习呢?元学习(meta learning)就是一种方法,也即学会学习。

本文就是利用对比来实现元学习,通过学习一个可转移的深度度量来比较图像之间的关系,即小样本学习;或者比较图像与类描述之间的关系,即零样本学习。现有的小样本学习方法通常将训练分解为一个辅助的元学习阶段,在该阶段中,以良好的初始条件、embedding或优化策略来学习可转移的知识,也就是先验知识。但是这些方法要么需要复杂的inference机制,要么需要复杂的RNN结构,要么通过优化策略进行微调来进行小样本学习,总之就是很复杂就对了,而本文提出的方法很简洁,也很灵活。

具体来说就是,提出了一个具有两个分支的Relation Network(RN),它通过比较query图像与每个新类中的少量样本图像之间的关系,来进行小样本学习:

  • 首先,嵌入模块(embedding model)为query和training图像生成各自的embedding;
  • 然后,通过一个关系模块(relation model)对这些embedding进行比较,判断它们的类别是否匹配。

RN的训练同样采用了episode策略,嵌入模块和关系模块都是端到端的元学习,注意RN中是一种可学习的非线性比较器,也就是一种可学习的非线性度量,这与MatchingNet和PrototypicalNet不同,MatchingNet中使用的是余弦距离,PrototypicalNet中是固定的线性度量,即平方欧氏距离。本文的RL比其它的方法更简单,因为没有使用RNN;也比其它的方法更快,因为没有微调。而且RL也可以直接泛化到零样本学习中,即在关系模型中比较query图像的embedding与类描述的embedding即可。

实现方法

1. 数据处理

对于小样本学习任务,有三种数据集:训练集,支持集和测试集。支持集和测试集共享同一个标签空间,而训练集有自己的标签空间,并且不和另外两种数据集共享。如果支持集中有 C C C个类,每个类有 K K K个带标签的样本,那么就可以称为 C C C-way K K K-shot。

虽然只用支持集原则上也可以训练出一个分类器,以将标签 y ^ \hat y y^​分类给测试集中的样本 x ^ \hat x x^,但由于支持集中缺少带标签的样本,由此训练出的分类器的性能并不能让人满意。因此就要在训练集上进行元学习,以提取出先验知识,从而可以更好的在支持集上进行小样本学习,进一步更好的对测试集进行分类。

一种有效利用训练集的方法就是通过基于episode的训练来模拟小样本学习。在每次迭代中,一个episode是指,从训练集中随机选出 C C C个类别,每个类中选择 K K K个带标签的样本作为一个样本集(sample set) S = { ( x i , y i ) } i = 1 m S=\lbrace (x_i,y_i) \rbrace ^m_{i=1} S={(xi​,yi​)}i=1m​,然后从每个类剩下的样本中选出一部分作为查询集(query set) Q = { ( x j , y j ) } j = 1 n Q=\lbrace (x_j,y_j) \rbrace ^n_{j=1} Q={(xj​,yj​)}j=1n​,该样本/查询集旨在模拟测试时遇到的支持/测试集,通过样本/查询集训练的模型也能用支持集来进一步微调。本文的实验就是用的这种基于episode的训练策略。

2. 模型

one-shot

RN包括两个模块:嵌入模块 f φ f_{\varphi} fφ​和关系模块 g ϕ g_{\phi} gϕ​,如下图所示:

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

对于one-shot来说,就是样本集 S S S中每个类只有一个样本,查询集 Q Q Q无所谓。将查询集 Q Q Q中的样本 x j x_j xj​和样本集 S S S中的样本 x i x_i xi​送入嵌入模块 f φ f_{\varphi} fφ​中,生成特征图 f φ ( x j ) f_{\varphi}(x_j) fφ​(xj​)和 f ϕ ( x i ) f_{\phi}(x_i) fϕ​(xi​),然后这两个特征图通过 C ( f φ ( x j ) , f ϕ ( x i ) ) C(f_{\varphi}(x_j),f_{\phi}(x_i)) C(fφ​(xj​),fϕ​(xi​))操作连结到一起,这里的 C ( ⋅ , ⋅ ) C(\cdot , \cdot) C(⋅,⋅)表示特征图在深度上的连结。然后将连结起来的特征图送入关系模块 g ϕ g_{\phi} gϕ​中,生成一个在0和1之间的标量,表示 x i x_i xi​和 x j x_j xj​之间的相似性,被称为关系分数。

因此,在 C C C-way one-shot设置下,共生成了 C C C个关系分数 r i , j r_{i,j} ri,j​:

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

K-shot

对于 K K K-shot来说,就是在 K > 1 K>1 K>1的情况下,也就是说样本集 S S S中每个类的样本数量大于1,查询集 Q Q Q还是无所谓。那么此时将每个类的所有样本在嵌入模块的输出进行element-wise的相加,得到样本集 S S S中每个类的特征图,然后和one-shot一样,与 Q Q Q中样本的特征图结合起来。

因此,不管是one-shot还是few-shot,对于 Q Q Q中的一个查询样本来说,关系分数的个数总是 C C C: Q Q Q中的每个查询样本,都要和 S S S中的每个类进行比较,看它和哪个类最相似,只不过one-shot情况下 S S S中的每个类只有一个样本,而 K K K-shot情况下 S S S中的每个类有多个样本,不过这多个样本还是形成了一个属于该类的特征图。总共有 C C C个类,所以关系分数的个数就是 C C C

zero-shot

zero-shot大概类似于one-shot,只不过不同于one-shot中支持集中每个类只有一个样本,zero-shot中每个类有一个语义向量 v c v_c vc​,那么由此对RL所做的修改为:使用第二个异构嵌入模块 f φ 2 f_{\varphi_2} fφ2​​来处理每个类的语义向量,关系模块还是和以前一样,那么每个查询样本 x j x_j xj​的关系分数为:

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

3. 损失函数

本文使用均方误差(MSE)来训练模型,将关系分数 r i , j r_{i,j} ri,j​回归到gt:两个匹配的样本之间的相似性为1,不匹配的则为0:

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

这样看的话,就很像一个分类问题,即判断是否属于某一类别,是为1,不是为0;但从概念上来说本文是在预测关系分数,尽管为了回归到gt只能生成{0,1}中的某个值,但这仍然是一个回归问题。

网络结构

few-shot

大多数小样本学习的模型使用4个卷积块来组成嵌入模块,本文也采用的是这样的结构,如下图所示。每个卷积块包括一个3x3x64的卷积,一个批归一化(batch normalisation)和一个ReLU非线性层,只有前两个卷积块有2x2的最大池化层,后两个没有,只是因为嵌入模块输出的特征图还要进一步在关系模块中进行卷积操作。关系模块包括两个卷积块和两个全连接层,每个卷积块包括一个3x3x64的卷积,后跟批归一化和ReLU非线性层,还有一个2x2的最大池化层,两个全连接层分别是8维和1维的。除了输出层是sigmoid外,所有的全连接层都是ReLU,输出层的sigmoid是为了生成在合理范围内的关系分数。

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

zero-shot

零样本学习的网络结构如下图所示,其中DNN子网是在ImageNet上经过预训练的一个现成的网络。

Learning to Compare: Relation Network for Few-Shot Learning 论文笔记前言实现方法网络结构为什么RL能work?

为什么RL能work?

与以往的小样本学习研究相比,它们采用的是固定度量(如余弦距离或平方欧氏距离)或固定特征(根据固定度量学到的embedding),和浅学习度量,本文提出的RL可以看作是学习深度embedding和深度非线性度量。

那么这为什么会有用呢?通过使用一个灵活的逼近函数来学习相似性,能够以数据驱动的方式学习到一个很好的度量,而不是手动选择正确的度量。像MatchingNet和PrototypicalNet中固定的度量假设特征只在元素方面进行比较,而与RL最相关的PrototypicalNet还假设embedding后的特征具有线性可分离性。这些都严重依赖于嵌入网络的有效性,因此受到嵌入网络生成不充分的区别表示的程度的限制。而在RL中,通过深度学习非线性相似度量和embedding,使得网络能够更好的识别匹配/不匹配的样本对儿。

继续阅读