天天看點

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

小樣本學習&元學習經典論文整理||持續更新

核心思想

  本文提出一種基于資料增強的小樣本學習算法(ICI)。本文的資料增強是通過自訓練(self-training)的方式實作的,具體而言就是利用有标簽的樣本先訓練得到一個分類器,然後預測無标簽樣本,得到僞标簽。選擇僞标簽中置信度較高的樣本,補充到訓練集中,實作資料擴充。通過疊代訓練的方式逐漸改善分類器的效果。網絡流程如下圖所示

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

  首先利用有标簽樣本訓練特征提取器和線性分類器,然後無标簽的樣本經過特征提取和簡單的線性分類後得到預測的僞标簽,利用執行個體置信度推斷子產品(Instance Credibility Inference,ICI)選擇出置信度較高的樣本和僞标簽,并利用其擴充支援集,而置信度較低的樣本則用于更新無标簽資料集。整個過程中最重要的一點就是如何計算預測得到的僞标簽的置信度,進而避免将分類錯誤的樣本補充到支援集中,導緻資料集被污染。下面具體介紹ICI子產品的處理過程,無論對于有标簽樣本還是無标簽樣本,網絡的預測結果 y i y_i yi​計算方式如下

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

式中 x i x_i xi​表示樣本對應的特征向量(特征提取網絡輸出的特征向量經過PCA降維後得到), β \beta β表示分類器的系數矩陣, ε i \varepsilon _i εi​表示均值為0,方差為 σ \sigma σ的高斯噪聲, γ i j \gamma_{ij} γij​用于修正執行個體 i i i被配置設定給類别 j j j的機率, γ i j \gamma_{ij} γij​的模越大,執行個體 i i i被配置設定給類别 j j j的難度越大。那麼本文的優化目标為

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

式中 R ( γ ) = ∑ i = 1 n ∥ γ i ∥ 2 R(\gamma)=\sum^n_{i=1}\left \|\gamma_i \right \|_2 R(γ)=∑i=1n​∥γi​∥2​表示懲罰項, λ \lambda λ表示懲罰項系數。為求解上述目标,本文的損失函數如下

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

令 ∂ L ∂ β = 0 \frac{\partial L}{\partial \beta}=0 ∂β∂L​=0可得

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

式中 ( ) † ()^{\dagger } ()†表示廣義逆矩陣。但值得注意的是,本文希望用 γ \gamma γ來度量執行個體的置信度,而不是用 β ^ \hat{\beta} β^​,這是因為簡單的線性分類器不足以對各種類别的樣本進行很好的分類,而且 β ^ \hat{\beta} β^​的值本身也依賴于 γ \gamma γ的取值。是以我們将上式代入損失函數 L L L中得到下式

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

式中 H = X ( X T X ) † X T H=X(X^TX)^{\dagger }X^T H=X(XTX)†XT。令 X ~ = ( I − H ) , Y ~ = X ~ Y \tilde{X}=(I-H),\tilde{Y}=\tilde{X}Y X~=(I−H),Y~=X~Y,則上式可簡化為

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

利用塊下降算法可以求解上式。首先 λ \lambda λ存在一個理論值,使得上式的解均為0,該理論值如下

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

那麼我們可以得到由0到 λ m a x \lambda_{max} λmax​之間一系列的 λ s \lambda_s λs​,對于每個 λ \lambda λ在求解目标函數時,都能獲得一條對應的 γ \gamma γ規則化路徑。而且當 λ \lambda λ由0變化到 ∞ \infty ∞時, γ \gamma γ的稀疏性不斷增強,直到他的所有元素都逐漸消失(vanish)。懲罰項 R ( γ ) R(\gamma) R(γ)會使得 γ \gamma γ一個執行個體接一個執行個體的消失,且消失的越早,則表明該執行個體的預測結果與真實值越為接近,是以根據 γ i \gamma_i γi​消失的順序可以得到對應的置信度 λ \lambda λ。

實作過程

網絡結構

  無具體介紹

損失函數

  見上文介紹

訓練政策

  訓練和推斷過程如下

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

創新點

  • 通過自訓練的方式擷取未标記樣本的僞标簽,并利用其擴充資料集,達到資料增強的目的
  • 設計了一種基于統計學的僞标簽置信度度量方法,選擇出置信度最高的樣本,用于支援資料集的補充

算法評價

  本文的整體思想并不複雜,了解難點主要集中在ICI子產品進行置信度度量的方面。本文提出的置信度度量方法是基于統計資訊的,根據實驗結果來看其性能提升作用還是比較明顯的,在多個資料集上都取得了SOTA的成績。

如果大家對于深度學習與計算機視覺領域感興趣,希望獲得更多的知識分享與最新的論文解讀,歡迎關注我的個人公衆号“深視”。

論文閱讀筆記《Instance Credibility Inference for Few-Shot Learning》

繼續閱讀