天天看點

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

作者:機器之心Pro

機器之心專欄

作者:蘇永怡

華南理工、A*STAR 團隊和鵬城實驗室聯合提出了針對測試階段訓練(TTT)問題的系統性分類準則。

域适應是解決遷移學習的重要方法,目前域适應當法依賴原域和目标域資料進行同步訓練。當源域資料不可得,同時目标域資料不完全可見時,測試階段訓練(Test- Time Training)成為新的域适應方法。目前針對 Test-Time Training(TTT)的研究廣泛利用了自監督學習、對比學習、自訓練等方法,然而,如何定義真實環境下的 TTT 卻被經常忽略,以至于不同方法間缺乏可比性。

近日,華南理工、A*STAR 團隊和鵬城實驗室聯合提出了針對 TTT 問題的系統性分類準則,通過區分方法是否具備順序推理能力(Sequential Inference)和是否需要修改源域訓練目标,對目前方法做了詳細分類。同時,提出了基于目标域資料定錨聚類(Anchored Clustering)的方法,在多種 TTT 分類下取得了最高的分類準确率,本文對 TTT 的後續研究指明了正确的方向,避免了實驗設定混淆帶來的結果不可比問題。研究論文已被 NeurIPS 2022 接收。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法
  • 論文:https://arxiv.org/abs/2206.02721
  • 代碼:https://github.com/Gorilla-Lab-SCUT/TTAC

一、引言

深度學習的成功主要歸功于大量的标注資料和訓練集與測試集獨立同分布的假設。在一般情況下,需要在合成資料上訓練,然後在真實資料上測試時,以上假設就沒辦法滿足,這也被稱為域偏移。為了緩解這個問題,域适應 (Domain Adaptation, DA) 誕生了。現有的 DA 工作要麼需要在訓練期間通路源域和目标域的資料,要麼同時在多個域進行訓練。前者需要模型在做适應 (Adaptation) 訓練期間總是能通路到源域資料,而後者需要更加昂貴的計算量。為了降低對源域資料的依賴,由于隐私問題或者存儲開銷不能通路源域資料,無需源域資料的域适應 (Source-Free Domain Adaptation, SFDA) 解決無法通路源域資料的域适應問題。作者發現 SFDA 需要在整個目标資料集上訓練多個輪次才能達到收斂,在面對流式資料需要及時做出推斷預測的時候 SFDA 無法解決此類問題。這種面對流式資料需要及時适應并做出推斷預測的更現實的設定,被稱為測試時訓練 (Test-Time Training, TTT) 或測試時适應(Test-Time Adaptation, TTA)。

作者注意到在社群裡對 TTT 的定義存在混亂進而導緻比較的不公平。論文以兩個關鍵的因素對現有的 TTT 方法進行分類:

  • 對于資料是流式出現的并需要對目前出現的資料作出及時預測的,稱之為單輪适應協定(One-Pass Adaptation);對于其他不符合以上設定的稱為多輪适應協定(Multi-Pass Adaptation),模型可能需要在整個測試集上進行多輪次的更新後,再進行從頭到尾的推斷預測。
  • 根據是否需要修改源域的訓練損失方程,比如引入額外的自監督分支以達到更有效的 TTT。

這篇論文的目标是解決最現實和最具挑戰性的 TTT 協定,即單輪适應并無需修改訓練損失方程。這個設定類似于 TENT[1]提出的 TTA,但不限于使用來自源域的輕量級資訊,如特征的統計量。鑒于 TTT 在測試時高效适應的目标,該假設在計算上是高效的,并大大提高了 TTT 的性能。作者将這個新的 TTT 協定命名為順序測試時訓練(sequential Test Time Training, sTTT)。

除了以上對不同 TTT 方法的分類外,論文還提出了兩個技術讓 sTTT 更加有效和準确:

  • 論文提出了測試時錨定聚類 (Test-Time Anchored Clustering, TTAC) 方法。
  • 為了降低錯誤僞标簽對聚類更新的影響,論文根據網絡對樣本的預測穩定性和自信度對僞标簽進行過濾。

二、方法介紹

論文分了四部分來闡述所提出的方法,分别是 1)介紹測試時訓練 (TTT) 的錨定聚類子產品,如圖 1 中的 Anchored Clustering 部分;2)介紹用于過濾僞标簽的一些政策,如圖 1 中的 Pseudo Label Filter 部分;3)不同于 TTT++[2]中的使用 L2 距離來衡量兩個分布的距離,作者使用了 KL 散度來度量兩個全局特征分布間的距離;4)介紹在測試時訓練 (TTT) 過程的特征統計量的有效更新疊代方法。最後第五小節給出了整個算法的過程代碼。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

第一部分 在錨定聚類裡,作者首先使用混合高斯對目标域的特征進行模組化,其中每個高斯分量代表一個被發現的聚類。然後,作者使用源域中每個類别的分布作為目标域分布的錨點來進行比對。通過這種方式,測試資料特征可以同時形成叢集,并且叢集與源域類别相關聯,進而達到了對目标域的推廣。概述來說就是,将源域和目标域的特征分别根據類别資訊模組化成:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

然後通過 KL 散度度量兩個混合高斯分布的距離,并通過減少 KL 散度來達到兩個域特征的比對。可是,在兩個混合高斯分布上直接求解 KL 散度并沒有閉式解,這導緻了無法使用有效的梯度優化方法。在這篇論文中,作者在源域和目标域中配置設定相同數量的叢集,每個目标域叢集被配置設定給一個源域叢集,這樣就可以将整個混合高斯的 KL 散度求解變成了各對高斯之間的 KL 散度之和。如下式:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

上式的閉式解形式為:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

在公式 2 中,源域叢集的參數可以線下收集完,而且由于隻用到了輕量化統計資料,是以不會導緻隐私洩漏問題且隻使用了少量的計算和存儲開銷。對于目标域的變量,涉及到了僞标簽的使用,作者為此設計了一套有效的且輕量的僞标簽過濾政策。

第二部分 僞标簽過濾的政策主要分為兩部分:

1)時序上一緻性預測的過濾:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

2)根據後驗機率的過濾:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

最後,使用過濾後的樣本來求解目标域叢集的統計量:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

第三部分 由于在錨定聚類中,部分被濾除的樣本并沒有參與目标域的估計。作者還對所有測試樣本進行全局特征對齊,類似錨定聚類中對叢集的做法,這裡将所有樣本看作一個整體的叢集,在源域和目标域分别定義

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法
NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

然後再次以最小化 KL 散度為目标對齊全局特征分布:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

第四部分 以上三部分都在介紹一些域對齊的手段,但在 TTT 過程中,想要估計一個目标域的分布是不簡單的,因為我們無法觀測整個目标域的資料。在前沿的工作中,TTT++[2]使用了一個特征隊列來存儲過去的部分樣本,來計算一個局部分布來估計整體分布。但這樣不但帶來了記憶體開銷還導緻了精度與記憶體之間的 trade off。在這篇論文中,作者提出了疊代更新統計量的方式來緩解記憶體開銷。具體的疊代更新式子如下:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

總的來說,整個算法如下算法 1 所示:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

三、實驗結果

正如引言部分所說,這篇論文中作者非常注重不同 TTT 政策下的不同方法的公平比較。作者将所有 TTT 方法根據以下兩個關鍵因素來分類:1)是否單輪适應協定 (One-Pass Adaptation) 和 2)修改源域的訓練損失方程,分别記為 Y/N 表示需要或不需要修改源域訓練方程,O/M 表示單輪适應或多輪适應。除此之外,作者在 6 個基準的資料集上進行了充分的對比實驗和一些進一步的分析。

如表一所示,TTT++[2]同時出現在了 N-O 和 Y-O 的協定下,是因為 TTT++[2]擁有一個額外的自監督分支,我們在 N-O 協定下将不添加自監督分支的損失,而在 Y-O 下可以正常使用此分子的損失。TTAC 在 Y-O 下也是使用了跟 TTT++[2]一樣的自監督分支。從表中可以看到,在所有的 TTT 協定下所有資料集下,TTAC 均取得到最優的結果;在 CIFAR10-C 和 CIFAR100-C 資料集上,TTAC 都取得了 3% 以上的提升。從表 2 - 表 5 分别是 ImageNet-C、CIFAR10.1、VisDA 上的資料,TTAC 均取到了最優的結果。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法
NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法
NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法
NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

此外,作者在多個 TTT 協定下同時做了嚴格的消融實驗,清晰地看出了每個部件的作用,如表 6 所示。首先從 L2 Dist 和 KLD 的對比中,可以看出使用 KL 散度來衡量兩個分布具有更優的效果;其次,發現如果單單使用 Anchored Clustering 或單獨使用僞标簽監督提升隻有 14%,但如果結合了 Anchored Cluster 和 Pseudo Label Filter 就可以看到性能顯著提高 29.15% -> 11.33%。這也可以看出每個部件的必要性和有效的結合。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域适應聚類方法

最後,作者在正文的尾部從五個次元對 TTAC 展開了充分的分析,分别是 sTTT (N-O)下的累計表現、TTAC 特征的 TSNE 可視化、源域無關的 TTT 分析、測試樣本隊列和更新輪次的分析、以 wall-clock 時間度量計算開銷。還有更多有趣的證明和分析會展示在文章的附錄中。

四、總結

本文隻是粗糙地介紹了 TTAC 這篇工作的貢獻點:對已有 TTT 方法的分類比較、提出的方法、以及各個 TTT 協定分類下的實驗。論文和附錄中會有更加詳細的讨論和分析。我們希望這項工作能夠為 TTT 方法提供一個公平的基準,未來的研究應該在各自的協定内進行比較。

[1] Dequan Wang, Evan Shelhamer, Shaoteng Liu, Bruno Olshausen, and Trevor Darrell. Tent: Fully test-time adaptation by entropy minimization. In International Conference on Learning Representations, 2021.

[2] Yuejiang Liu, Parth Kothari, Bastienvan Delft, Baptiste Bellot-Gurlet, Taylor Mordan, and Alexandre Alahi. Ttt++: When does self-supervised test-time training fail or thrive? In Advances in Neural Information Processing Systems, 2021.

繼續閱讀