天天看點

NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

本文是對我們NeurIPS 2022被接收的文章“S-Prompts Learning with Pre-trained Transformers: An Occam’s Razor for Domain Incremental Learning”的介紹。在該工作中我們提出一個針對域增量學習的簡單高效的方法(S-Prompts)。我們設計Prompts訓練政策對每個域的知識進行獨立學習,進而将預訓練模型增量地遷移學習不同域。所提出的方法可以讓新舊知識互不幹擾,并達到雙赢的結果。很榮幸地,我們的文章被NeurIPS 2022收錄,項目代碼即将開源,歡迎大家試用。
NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

論文連結:

https://arxiv.org/pdf/2207.12819.pdf

代碼連結:

https://github.com/iamwangyabin/S-Prompts (暫未開放,敬請期待)

一、概述

增量學習(連續學習)目标是在資料流中增量地訓練一個機器學習模型,使得模型能夠在獲得新知識的同時不遺忘已經學習到的舊知識。災難性遺忘(catastrophic forgetting)現象是增量學習的最大挑戰之一,也就是模型在學習新知識的同時舊知識會出現嚴重的遺忘,進而導緻模型在舊任務上性能下降。早期工作通過存儲少量舊資料或者設計正則損失函數來維持模型在舊任務上的精度,然而這不可避免地限制了在新任務上的學習能力,如圖1。是以大多數增量方法最後會陷入新舊任務之間的拔河遊戲(零和遊戲)–一方獲得精度的同時會讓另一方損失精度。這個挑戰在域增量(Domain-incremental learning)問題上尤為明顯,不同域的知識可能很難在同一個空間中共存。此外儲存舊任務的資料會占用大量存儲空間,并且有隐私問題和新舊資料量不平衡的問題。是以本工作從實際應用需求出發,聚焦在無存儲樣本的域增量學習任務。

在本工作中,我們打破成規提出一個雙赢政策來解決域增量問題,通過學習跨域獨立的Prompts使得模型在每個域都得到最佳性能而沒有任何互相幹擾,并将學習到的Prompts存儲來消除災難性遺忘問題。所提出的新的增量模式僅僅為每個任務增加微不足道的參數(Prompts)學習目前域的知識,而預訓練網絡的其餘部分都是當機固定的,是以非常簡單且有效。為了在推理階段選擇合适的Prompts,每個階段的訓練資料特征都會用K-Means計算得到域中心作為這個域的表示。在推理時,對于一個樣本,我們先提取這個樣本在預訓練模型(ViT)的特征,再将這個特征用K-NN找存儲的最近域中心作為挑選Prompts的依據。假設有S個階段(Session),我們最終會獨立地學習S個域的Prompts,是以本方法命名為S-Prompts。

此外為更好的學習不同域的Prompts,我們提出全新的針對視覺預訓練模型的Prompts學習方法(S-iPrompts)以及針對視覺-語言預訓練模型的Prompts學習方法(S-liPrompts)。本方法在三個标準DIL基準資料集上取得了較高的成績,S-Prompts明顯優于最新的無樣本增量方法(平均精度相對提高30%),甚至對于使用樣本的方法也高出6%精度。S-Prompts僅僅有極微小的參數增加,例如,在S-liPrompts中每個域增加0.03%參數量。

NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

圖1 現有工作和本工作差別

二、方法

NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

圖2 S-liPrompts結構

2.1 S-Prompts架構簡述

S-Prompts的核心思想是借助預訓練模型,對每個域逐個學習Prompts。在增量訓練時,預訓練模型始終是固定的,通過訓練Prompts可以将預訓練模型調整遷移到不同的域中。在這樣設定下,不同的域的知識被編碼進僅有少量參數的Prompts中,這樣不僅避免了存儲舊樣本,同時可以極大地減少災難性遺忘。

然而這種設計在推理時需要對給定的樣本挑選合适的Prompts。由于我們已經有了預訓練模型,那麼預訓練模型本身可以幫助選擇合适的Prompts。具體而言,如圖2所示,我們應用K-Means來得到每個域的訓練資料的特征中心,這些特征是直接使用預訓練模型提取的,并沒有應用Prompts。在推理時,我們直接使用K-NN來查詢應該使用哪個域的Prompts。由于域增量任務的特征往往差别很大,這種簡單的做法可以在DIL中獲得良好的性能。

2.2 圖像Prompts(S-iPrompts)學習政策

在S-iPrompts的方法中,對于一個域S,我們使用一組獨立的連續可學習參數(即Prompts) P s i ∈ R L i × D i P^i_s \in \mathbb{R}^{L_{i} \times D_i} Psi​∈RLi​×Di​ 作為預訓練ViT的輸入的一部分,其中 L i ∈ R L_i \in \mathbb{R} Li​∈R 和 D i ∈ R D_i \in \mathbb{R} Di​∈R 分别是Prompts長度和次元。

如圖2所示,給定域 s s s 的圖檔 x x x ,ViT的輸入為 x = [ x i m g , P s i , x c l s ] x = [x_{img}, P^i_s ,x_{cls}] x=[ximg​,Psi​,xcls​] ,其中 x i m g x_{img} ximg​ 是圖檔tokens, x c l s x_{cls} xcls​ 是預訓練模型ViT的class tokens。當在新的域 s + 1 s+1 s+1 上增量訓練時,會添加一組新的獨立的Prompts P s + 1 i P^i_{s+1} Ps+1i​ 。是以,按順序學習所有域會産生一個域的Prompt Pool。Prompt Pool可以定義為 P i = { P 1 i , P 2 i , . . . , P S i } \mathcal{P}^i =\left \{ P^i_1, P^i_2, ...,P^i_S \right \} Pi={P1i​,P2i​,...,PSi​} 。

對于Classifier,每個session都會學習單獨分類器并且存儲下來,在推理時挑選對應的classifier。對于ViT,分類器就是全連結層,表示為 [ W s , b s ] \left [ W_s, b_s \right ] [Ws​,bs​] , 其中 W s ∈ R C × D i W_s \in {R}^{C \times D_i} Ws​∈RC×Di​ , b s ∈ R C b_s \in \mathbb{R}^{C} bs​∈RC , C C C 分别是特征次元和總共的類别數量。

每個增量階段都有獨立的分類器,是以我們也有一個分類器池 P f c = { [ W 1 , b 1 ] , [ W 2 , b 2 ] , . . . , [ W S , b S ] } \mathcal{P}^{fc} =\left \{ \left [ W_1, b_1 \right ] , \left [ W_2, b_2 \right ], ...,\left [ W_S, b_S \right ] \right \} Pfc={[W1​,b1​],[W2​,b2​],...,[WS​,bS​]} 。

2.3 語言-圖像Prompts(S-liPrompts)學習政策

S-liPrompts是為了能夠将現在很多的視覺-語言預訓練模型,例如CLIP,更好地增量遷移到下遊任務上。對于階段 s s s,我們使用 M M M 個可學習的向量 v v v 作為語言端的prompts P s l = { v 1 , v 2 , . . . , v M } ∈ R L l × D l P^l_s=\left \{v_1, v_2, ..., v_M \right \} \in \mathbb{R} ^{L_{l} \times D_l} Psl​={v1​,v2​,...,vM​}∈RLl​×Dl​ ,其中 L l , D l L_{l}, D_l Ll​,Dl​ 分别是Prompts的長度和次元。

對于第 j j j 個類, 語言編碼器(text encoder)的全部輸入為 t j = { P s l , c j } t_j = \left \{P^l_s, c_j\right \} tj​={Psl​,cj​} , 其中 c j c_j cj​ 是第 j j j 個類的類别名稱編碼。

語言Prompts同樣和各自的域相關聯,在訓練完所有增量階段後,可以同樣得到一個Prompt Pool存放所有的語言Prompts,如 P l = { P 1 l , P 2 l , . . . , P S l } \mathcal{P}^l = \left \{ P^l_1, P^l_2, ..., P^l_S \right \} Pl={P1l​,P2l​,...,PSl​} 。

CLIP的語言編碼器 g g g 以上文定義的 t j t_j tj​ 作為輸入并且輸出一個向量表示作為某個類的特征。

令 f ( x ) f(x) f(x) 為視覺編碼器 f f f 提取的圖檔 x x x 的特征, { g ( t j ) } j C \left \{ g(t_j) \right \}^C_j {g(tj​)}jC​ 是使用文字編碼器 g ( . ) g(.) g(.) 提取的類别 j j j 的特征,CLIP分類器使用如下公式計算預測機率。

p ( y j ∣ x ) = exp ⁡ ( ⟨ f ( x ) , g ( t j ) ⟩ ) ∑ k = 1 C exp ⁡ ( ⟨ f ( x ) , g ( t k ) ⟩ ) p(y_j|x)=\frac{\exp (\left \langle f(x),g(t_j) \right \rangle )}{ {\textstyle \sum_{k=1}^{C}}\exp (\left \langle f(x),g(t_k) \right \rangle) } p(yj​∣x)=∑k=1C​exp(⟨f(x),g(tk​)⟩)exp(⟨f(x),g(tj​)⟩)​

三、實驗結果

在實驗設定上,本工作選擇了三個在DIL任務上有代表性的大型評測基準:CDDB,CORE50和DomainNet。所有方法均使用相同預訓練ViT-B/16或者同性能的Backbone(針對DyTox使用的是預訓練ConViT)。

表1、2和3結果展示了所提出的S-iPrompts和S-liPrompts極大程度地超越了已有的其他無樣本增量方法。甚至S-liPrompts得到了相對30%的精度提升。此外相對于存樣本的方法,所提出的S-Prompts在不存樣本的情況下也取得了6%左右的精度提升。

NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

表1 CDDB資料集結果

NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

表2 CORE50資料集結果

NeurIPS 2022 | S-Prompts:擺脫新舊任務零和遊戲,實作雙赢的域增量學習方法一、概述二、方法三、實驗結果四、方法總結與未來展望

表3 DomainNet資料集結果

四、方法總結與未來展望

在本工作中我們提出使用Prompts來解決域增量學習中的災難性遺忘現象,并且在多個資料集取得優秀的性能表現。盡管所提出的方法同時适用任務增量任務(TIL),但是還無法做類增量問題(CIL)。此外Prompts的設計還存在較大的優化空間,例如Prompts目前隻作為最初的輸入,但是Prompts的放置位置有很多可能性。最後,Prompt Tuning作為Efficient Finetuning技術的一種,未來可能會在增量學習中得到更深入的應用,特别是大規模預訓練模型的興起,會進一步推進對增量學習問題的研究。

-The End-

關于我“門”

将門是一家以專注于發掘、加速及投資技術驅動型創業公司的新型創投機構,旗下涵蓋将門創新服務、将門-TechBeat技術社群以及将門創投基金。

将門成立于2015年底,創始團隊由微軟創投在中國的創始團隊原班人馬建構而成,曾為微軟優選和深度孵化了126家創新的技術型創業公司。

如果您是技術領域的初創企業,不僅想獲得投資,還希望獲得一系列持續性、有價值的投後服務,歡迎發送或者推薦項目給我“門”:

[email protected]

繼續閱讀