本文介紹我們最近的一篇TPAMI工作:Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice。 域适應(即domain adaptation)是遷移學習中的重要課題。該課題的目标是:
輸入有标簽的源域資料和無标簽的目标域資料,輸出一個适用于目标域的模型。 源域和目标域假設任務相同但是資料分布不同
既然源域和目标域的資料分布不同,該任務的經典解決方法是:
找到一個特征空間,将分布不同的源域和目标域資料映射到該特征空間後,希望源域和目标域的資料分布差異盡可能小; 這樣基于源域資料訓練的模型,就可以用于目标域資料上
如何找到該特征空間,更具體來說,如何衡量兩個域資料分布之間的差異是域适應任務的核心問題。
通過對抗訓練的方式實作兩個域的資料分布對齊在域适應任務中被廣泛采用[1]。近期很多對抗域适應的算法采用特征映射網絡和分類器進行對抗的方式[2,3,4,5]。雖然基于分類器進行對抗訓練的方法[2,3,4,5]取得了不錯的結果,但是這些算法與現有理論并不是完全吻合的;也就是說,理論和算法之間存在一定的差距。 出于此目的,我們對現有的域适應理論進行了改進,使其可以更好的支撐現有算法。同時,基于該理論架構,我們提出了一系列新的算法,并在closed set, partial, and open set 域适應三個任務上驗證了其有效性。該文章的要點可以總結如下:
- 理論方面:提出了Multi-Class Scoring Disagreement (MCSD) divergence來衡量兩個域資料分布之間的差異;其中MCSD可以充分衡量兩個scoring hypotheses(可以了解為分類器) 之間的差異。基于MCSD divergence, 我們提出了新的Adaptation Bound, 并詳細讨論了我們理論架構和之前架構的關系。
- 算法方面:基于MCSD divergence 的理論,我們提出了一套新的代碼架構Multi-class Domain-adversarial learning Networks (McDalNets)。McDalNets的不同實作與近期的流行方法相似或相同,是以從理論層面支撐了這些方法 [2,3,4,5]。此外,我們提出了一個新的算法SymmNets-V2, 該方法是我們之前會議工作[3]的改進版本。
- 實踐方面:我們在closed set, partial, and open set 三種實驗設定下驗證了我們提出方法的有效性。代碼連結:Code
理論方面:
如上文所言,如何衡量兩個域之間的差異是域适應任務的核心問題。為了更加細粒度地衡量兩個域之間的差異,我們引入了如下的 MCSD divergence:
其中
充分衡量了scoring functions
在域
上的不一緻性(相對于下面将要描述的其他度量方法).
的定義如下:
是ramp loss,
指代absolute margin function
的第
個值。上述定義有些複雜,我們接下來對其直覺描述:
的每一列
計算了violations of absolute margin function
,進而
度量了
之間margin violations的差異,一個直覺的例子如Fig 1(c)所示:
到了這裡,大家應該會疑惑:這個MCSD divergence 看上去挺複雜的,它有什麼好處?MCSD的優勢如下:
理論角度:MCSD可以 充分度量兩個scoring functions 的差異!!同時導出後續的bound.
算法角度:對scoring functions 的差異的充分度量可以直接支撐基于分類器進行對抗訓練的方法[2,3,4,5].
為了展示MCSD對scoring functions 差異的充分度量,我們基于absolute margin function
引入其他domain divergence [6,7] 的變種或等效形式。
是absolute margin-based variant of margin disparity (MD) [6]:
, where
是relative margin function. 進而基于
得到的divergence 是MDD 的變種。
是ablolute margin-based equivalent of the hypothesis disagreement (HD) [7]
. 進而基于
得到的divergence 等效于
作為3種不同的度量scoring functions差異的方法,其直覺對比如Fig 1所示,可以總結如下:
- 采用0-1二值loss隻衡量了 的最終類别預測是否一緻。
- 相對 , 通過引入margin 在0和1之間做了一個平滑的過渡。
- 以上兩者都隻考慮了scoring functions的部分輸出, 首次将scoring functions 的所有輸出值加以考慮。故而MCSD可以充分度量scoring functions 的差異。
基于MCSD divergence, 我們可以得到如下的bound:
,其中
是targer error,
是source error,
可視為與資料集合hypothesis space相關的常數。相應的PAC bound也可以導出。
總的來說,我們提出了一種MCSD divergence 來充分度量兩個scoring functions的差異,進而提出了一種新的adaptation bound. 那麼充分度量兩個scoring functions的差異有什麼好處呢?後續的對比實驗經驗性的回答了該問題。
算法方面:
上述理論可以推導出一系列的算法,我們将這些算法統一命名為McDalNets. 基于上述bound, 為了最小化target error
,我們需要找到可以最小化
的feature extractor
以及可以最小化source error
的
和
. 将
展開成
的形式可以得到如下的優化目标:
其中
分别是分布
經由
映射得到的特征分布。該優化目标如下圖所示:
上述目标仍然難以直接優化,因為ramp loss
會導緻梯度消失的問題。為了便于優化,我們引入了一些MCSD的替代度量方法。這些替代度量方法應該具有如下特點:
- 當 在domain 上的輸出越趨于一緻,替代度量方法的值越小
- 當 在domain 上的輸出越差異越大,替代度量方法的值越大
我們在本文中采用了三種MCSD的替代度量方法,分别是:
其中
是softmax函數,
是KL散度,
是交叉熵函數.
其他具有上述兩點特點且便于優化的函數都可以用來作為MCSD的替代度量方法。當采用
loss 作為替代度量時,McDalNet與MCD [2] 方法極其相似。 需要強調的是,雖然MCD算法是從
divergence [7] 推導而出的,但是MCD算法與
divergence存在明顯gap:MCD算法采用L_1 loss 衡量了classifiers ouputs 在element-wise的差異,而
divergence 隻考慮了classifiers 類别預測的不一緻性。考慮到MCSD是基于對classifiers ouputs 在element-wise的差異的度量,是以MCSD divergence 可以更直接,更緊密的支撐MCD這類基于classifiers outputs 差異做對抗訓練的方法。
類似
,我們也可以基于
和
得到對應的類似McDalNet的算法。其中基于
得到的方法完全等效于DANN [1], 基于
得到的方法是MDD [6] 的一個變種。
我們将不同McDalNet 的算法在标準的域适應資料集上進行對比,結果如下圖所示:
除了上述的McDalNet架構,基于MCSD divergence, 我們還引入了一個Domain-Symmetric Networks (SymmNets)的新架構,如下圖所示。
該架構是基于CVPR 的論文[3]做的改進,是以我們稱之為SymmNets-V2. 相對于McDalNets, SymmNets-V2 沒有額外的task classifier,而是将其與classifiers for 域對齊進行了合并。該方法在網絡結構上的鮮明特點是将兩個classifiers拼接到一起,并用拼接得到的classifier用作域對齊;通過這種方式,我們賦予了兩個classifiers 明确的domain 資訊,同時取得了更優的實驗結果。SymmNets-V2 的優化目标如下:
其中
是分類損失,用來賦予
類别資訊,
用來增大
的輸出差異,
和
分别用來減小
在源域資料和目标域資料上的輸出差異。其具體定義和與MCSD的聯系請參考論文。
對于熟悉DANN [1] 的讀者,可以将SymmNets看做将category information 引入DANN的直接擴充。具體來說,如果我們分别将
中的所有類别當成整體,那麼整體化之後的
就分别對應着DANN 二分類domain classifier 中的源域和目标域;這樣SymmNets中的增大/減小
的輸出差異就對應着DANN中的domain discrimination/domain confusion. 将DANN 二分類domain classifier 中的源域和目标域擴充成由
拼接成的2K 分類器,可以為在域對齊過程中引入category information做好模型結構準備。
實踐方面:我們在closed set, partial, and open set domain adaptation三個任務共七個資料集上驗證了我們提出的McDalNets和SymmNets的有效性。相對closed set 的任務,partial and open set domain adaptation任務中的難度增大很大程度是兩個域中共享類别的樣本與其中一個域中獨有類别的樣本在adaptation 過程中發生了錯誤對齊帶來的;是以SymmNets中對category information 的引入和對category level alignment 的促進可以極大的緩解該錯誤對齊現象,進而對partial 和open set domain adaptation帶來幫助。最後,我們通過如下的t-SNE可視化來說明我們提出的SymmNets的有效性。
[1] Domain-Adversarial Training of Neural Networks, JMLR16
[2] Maximum Classifier Discrepancy for Unsupervised Domain Adaptation, CVPR18
[3] Domain-Symmetric Networks for Adversarial Domain Adaptation, CVPR19
[4] Unsupervised Domain Adaptation via Regularized Conditional Alignment, ICCV19
[5] Sliced wasserstein discrepancy for unsupervised domain adaptation, CVPR19
[6] Bridging Theory and Algorithm for Domain Adaptation, ICML19
[7] A theory of learning from different domains,ML10