天天看點

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

前言

近年來深度學習模型在視覺任務上取得了巨大的成功,但這種成功有一部分原因來自于龐大的标記資料以及大量的計算資源,這使得這些模型在處理幾乎沒有标記資料的新類時顯得非常乏力。對于我們人類來說,在識别物體時,僅需少量的圖像,或者甚至不需要圖像而僅僅根據對物體的描述,就能根據以往的知識來識别物體。這是由于我們人類有先驗知識,我們會利用自己的先驗知識進行學習。如何讓模型能夠實作這種快速學習呢?元學習(meta learning)就是一種方法,也即學會學習。

本文就是利用對比來實作元學習,通過學習一個可轉移的深度度量來比較圖像之間的關系,即小樣本學習;或者比較圖像與類描述之間的關系,即零樣本學習。現有的小樣本學習方法通常将訓練分解為一個輔助的元學習階段,在該階段中,以良好的初始條件、embedding或優化政策來學習可轉移的知識,也就是先驗知識。但是這些方法要麼需要複雜的inference機制,要麼需要複雜的RNN結構,要麼通過優化政策進行微調來進行小樣本學習,總之就是很複雜就對了,而本文提出的方法很簡潔,也很靈活。

具體來說就是,提出了一個具有兩個分支的Relation Network(RN),它通過比較query圖像與每個新類中的少量樣本圖像之間的關系,來進行小樣本學習:

  • 首先,嵌入子產品(embedding model)為query和training圖像生成各自的embedding;
  • 然後,通過一個關系子產品(relation model)對這些embedding進行比較,判斷它們的類别是否比對。

RN的訓練同樣采用了episode政策,嵌入子產品和關系子產品都是端到端的元學習,注意RN中是一種可學習的非線性比較器,也就是一種可學習的非線性度量,這與MatchingNet和PrototypicalNet不同,MatchingNet中使用的是餘弦距離,PrototypicalNet中是固定的線性度量,即平方歐氏距離。本文的RL比其它的方法更簡單,因為沒有使用RNN;也比其它的方法更快,因為沒有微調。而且RL也可以直接泛化到零樣本學習中,即在關系模型中比較query圖像的embedding與類描述的embedding即可。

實作方法

1. 資料處理

對于小樣本學習任務,有三種資料集:訓練集,支援集和測試集。支援集和測試集共享同一個标簽空間,而訓練集有自己的标簽空間,并且不和另外兩種資料集共享。如果支援集中有 C C C個類,每個類有 K K K個帶标簽的樣本,那麼就可以稱為 C C C-way K K K-shot。

雖然隻用支援集原則上也可以訓練出一個分類器,以将标簽 y ^ \hat y y^​分類給測試集中的樣本 x ^ \hat x x^,但由于支援集中缺少帶标簽的樣本,由此訓練出的分類器的性能并不能讓人滿意。是以就要在訓練集上進行元學習,以提取出先驗知識,進而可以更好的在支援集上進行小樣本學習,進一步更好的對測試集進行分類。

一種有效利用訓練集的方法就是通過基于episode的訓練來模拟小樣本學習。在每次疊代中,一個episode是指,從訓練集中随機選出 C C C個類别,每個類中選擇 K K K個帶标簽的樣本作為一個樣本集(sample set) S = { ( x i , y i ) } i = 1 m S=\lbrace (x_i,y_i) \rbrace ^m_{i=1} S={(xi​,yi​)}i=1m​,然後從每個類剩下的樣本中選出一部分作為查詢集(query set) Q = { ( x j , y j ) } j = 1 n Q=\lbrace (x_j,y_j) \rbrace ^n_{j=1} Q={(xj​,yj​)}j=1n​,該樣本/查詢集旨在模拟測試時遇到的支援/測試集,通過樣本/查詢集訓練的模型也能用支援集來進一步微調。本文的實驗就是用的這種基于episode的訓練政策。

2. 模型

one-shot

RN包括兩個子產品:嵌入子產品 f φ f_{\varphi} fφ​和關系子產品 g ϕ g_{\phi} gϕ​,如下圖所示:

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

對于one-shot來說,就是樣本集 S S S中每個類隻有一個樣本,查詢集 Q Q Q無所謂。将查詢集 Q Q Q中的樣本 x j x_j xj​和樣本集 S S S中的樣本 x i x_i xi​送入嵌入子產品 f φ f_{\varphi} fφ​中,生成特征圖 f φ ( x j ) f_{\varphi}(x_j) fφ​(xj​)和 f ϕ ( x i ) f_{\phi}(x_i) fϕ​(xi​),然後這兩個特征圖通過 C ( f φ ( x j ) , f ϕ ( x i ) ) C(f_{\varphi}(x_j),f_{\phi}(x_i)) C(fφ​(xj​),fϕ​(xi​))操作連結到一起,這裡的 C ( ⋅ , ⋅ ) C(\cdot , \cdot) C(⋅,⋅)表示特征圖在深度上的連結。然後将連結起來的特征圖送入關系子產品 g ϕ g_{\phi} gϕ​中,生成一個在0和1之間的标量,表示 x i x_i xi​和 x j x_j xj​之間的相似性,被稱為關系分數。

是以,在 C C C-way one-shot設定下,共生成了 C C C個關系分數 r i , j r_{i,j} ri,j​:

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

K-shot

對于 K K K-shot來說,就是在 K > 1 K>1 K>1的情況下,也就是說樣本集 S S S中每個類的樣本數量大于1,查詢集 Q Q Q還是無所謂。那麼此時将每個類的所有樣本在嵌入子產品的輸出進行element-wise的相加,得到樣本集 S S S中每個類的特征圖,然後和one-shot一樣,與 Q Q Q中樣本的特征圖結合起來。

是以,不管是one-shot還是few-shot,對于 Q Q Q中的一個查詢樣本來說,關系分數的個數總是 C C C: Q Q Q中的每個查詢樣本,都要和 S S S中的每個類進行比較,看它和哪個類最相似,隻不過one-shot情況下 S S S中的每個類隻有一個樣本,而 K K K-shot情況下 S S S中的每個類有多個樣本,不過這多個樣本還是形成了一個屬于該類的特征圖。總共有 C C C個類,是以關系分數的個數就是 C C C

zero-shot

zero-shot大概類似于one-shot,隻不過不同于one-shot中支援集中每個類隻有一個樣本,zero-shot中每個類有一個語義向量 v c v_c vc​,那麼由此對RL所做的修改為:使用第二個異構嵌入子產品 f φ 2 f_{\varphi_2} fφ2​​來處理每個類的語義向量,關系子產品還是和以前一樣,那麼每個查詢樣本 x j x_j xj​的關系分數為:

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

3. 損失函數

本文使用均方誤差(MSE)來訓練模型,将關系分數 r i , j r_{i,j} ri,j​回歸到gt:兩個比對的樣本之間的相似性為1,不比對的則為0:

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

這樣看的話,就很像一個分類問題,即判斷是否屬于某一類别,是為1,不是為0;但從概念上來說本文是在預測關系分數,盡管為了回歸到gt隻能生成{0,1}中的某個值,但這仍然是一個回歸問題。

網絡結構

few-shot

大多數小樣本學習的模型使用4個卷積塊來組成嵌入子產品,本文也采用的是這樣的結構,如下圖所示。每個卷積塊包括一個3x3x64的卷積,一個批歸一化(batch normalisation)和一個ReLU非線性層,隻有前兩個卷積塊有2x2的最大池化層,後兩個沒有,隻是因為嵌入子產品輸出的特征圖還要進一步在關系子產品中進行卷積操作。關系子產品包括兩個卷積塊和兩個全連接配接層,每個卷積塊包括一個3x3x64的卷積,後跟批歸一化和ReLU非線性層,還有一個2x2的最大池化層,兩個全連接配接層分别是8維和1維的。除了輸出層是sigmoid外,所有的全連接配接層都是ReLU,輸出層的sigmoid是為了生成在合理範圍内的關系分數。

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

zero-shot

零樣本學習的網絡結構如下圖所示,其中DNN子網是在ImageNet上經過預訓練的一個現成的網絡。

Learning to Compare: Relation Network for Few-Shot Learning 論文筆記前言實作方法網絡結構為什麼RL能work?

為什麼RL能work?

與以往的小樣本學習研究相比,它們采用的是固定度量(如餘弦距離或平方歐氏距離)或固定特征(根據固定度量學到的embedding),和淺學習度量,本文提出的RL可以看作是學習深度embedding和深度非線性度量。

那麼這為什麼會有用呢?通過使用一個靈活的逼近函數來學習相似性,能夠以資料驅動的方式學習到一個很好的度量,而不是手動選擇正确的度量。像MatchingNet和PrototypicalNet中固定的度量假設特征隻在元素方面進行比較,而與RL最相關的PrototypicalNet還假設embedding後的特征具有線性可分離性。這些都嚴重依賴于嵌入網絡的有效性,是以受到嵌入網絡生成不充分的差別表示的程度的限制。而在RL中,通過深度學習非線性相似度量和embedding,使得網絡能夠更好的識别比對/不比對的樣本對兒。

繼續閱讀