本文介绍我们最近的一篇TPAMI工作:Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice。 域适应(即domain adaptation)是迁移学习中的重要课题。该课题的目标是:
输入有标签的源域数据和无标签的目标域数据,输出一个适用于目标域的模型。 源域和目标域假设任务相同但是数据分布不同
既然源域和目标域的数据分布不同,该任务的经典解决方法是:
找到一个特征空间,将分布不同的源域和目标域数据映射到该特征空间后,希望源域和目标域的数据分布差异尽可能小; 这样基于源域数据训练的模型,就可以用于目标域数据上
如何找到该特征空间,更具体来说,如何衡量两个域数据分布之间的差异是域适应任务的核心问题。
通过对抗训练的方式实现两个域的数据分布对齐在域适应任务中被广泛采用[1]。近期很多对抗域适应的算法采用特征映射网络和分类器进行对抗的方式[2,3,4,5]。虽然基于分类器进行对抗训练的方法[2,3,4,5]取得了不错的结果,但是这些算法与现有理论并不是完全吻合的;也就是说,理论和算法之间存在一定的差距。 出于此目的,我们对现有的域适应理论进行了改进,使其可以更好的支撑现有算法。同时,基于该理论框架,我们提出了一系列新的算法,并在closed set, partial, and open set 域适应三个任务上验证了其有效性。该文章的要点可以总结如下:
- 理论方面:提出了Multi-Class Scoring Disagreement (MCSD) divergence来衡量两个域数据分布之间的差异;其中MCSD可以充分衡量两个scoring hypotheses(可以理解为分类器) 之间的差异。基于MCSD divergence, 我们提出了新的Adaptation Bound, 并详细讨论了我们理论框架和之前框架的关系。
- 算法方面:基于MCSD divergence 的理论,我们提出了一套新的代码框架Multi-class Domain-adversarial learning Networks (McDalNets)。McDalNets的不同实现与近期的流行方法相似或相同,因此从理论层面支撑了这些方法 [2,3,4,5]。此外,我们提出了一个新的算法SymmNets-V2, 该方法是我们之前会议工作[3]的改进版本。
- 实践方面:我们在closed set, partial, and open set 三种实验设置下验证了我们提出方法的有效性。代码链接:Code
理论方面:
如上文所言,如何衡量两个域之间的差异是域适应任务的核心问题。为了更加细粒度地衡量两个域之间的差异,我们引入了如下的 MCSD divergence:
其中
充分衡量了scoring functions
在域
上的不一致性(相对于下面将要描述的其他度量方法).
的定义如下:
是ramp loss,
指代absolute margin function
的第
个值。上述定义有些复杂,我们接下来对其直观描述:
的每一列
计算了violations of absolute margin function
,进而
度量了
之间margin violations的差异,一个直观的例子如Fig 1(c)所示:
到了这里,大家应该会疑惑:这个MCSD divergence 看上去挺复杂的,它有什么好处?MCSD的优势如下:
理论角度:MCSD可以 充分度量两个scoring functions 的差异!!同时导出后续的bound.
算法角度:对scoring functions 的差异的充分度量可以直接支撑基于分类器进行对抗训练的方法[2,3,4,5].
为了展示MCSD对scoring functions 差异的充分度量,我们基于absolute margin function
引入其他domain divergence [6,7] 的变种或等效形式。
是absolute margin-based variant of margin disparity (MD) [6]:
, where
是relative margin function. 进而基于
得到的divergence 是MDD 的变种。
是ablolute margin-based equivalent of the hypothesis disagreement (HD) [7]
. 进而基于
得到的divergence 等效于
作为3种不同的度量scoring functions差异的方法,其直观对比如Fig 1所示,可以总结如下:
- 采用0-1二值loss只衡量了 的最终类别预测是否一致。
- 相对 , 通过引入margin 在0和1之间做了一个平滑的过渡。
- 以上两者都只考虑了scoring functions的部分输出, 首次将scoring functions 的所有输出值加以考虑。故而MCSD可以充分度量scoring functions 的差异。
基于MCSD divergence, 我们可以得到如下的bound:
,其中
是targer error,
是source error,
可视为与数据集合hypothesis space相关的常数。相应的PAC bound也可以导出。
总的来说,我们提出了一种MCSD divergence 来充分度量两个scoring functions的差异,进而提出了一种新的adaptation bound. 那么充分度量两个scoring functions的差异有什么好处呢?后续的对比实验经验性的回答了该问题。
算法方面:
上述理论可以推导出一系列的算法,我们将这些算法统一命名为McDalNets. 基于上述bound, 为了最小化target error
,我们需要找到可以最小化
的feature extractor
以及可以最小化source error
的
和
. 将
展开成
的形式可以得到如下的优化目标:
其中
分别是分布
经由
映射得到的特征分布。该优化目标如下图所示:
上述目标仍然难以直接优化,因为ramp loss
会导致梯度消失的问题。为了便于优化,我们引入了一些MCSD的替代度量方法。这些替代度量方法应该具有如下特点:
- 当 在domain 上的输出越趋于一致,替代度量方法的值越小
- 当 在domain 上的输出越差异越大,替代度量方法的值越大
我们在本文中采用了三种MCSD的替代度量方法,分别是:
其中
是softmax函数,
是KL散度,
是交叉熵函数.
其他具有上述两点特点且便于优化的函数都可以用来作为MCSD的替代度量方法。当采用
loss 作为替代度量时,McDalNet与MCD [2] 方法极其相似。 需要强调的是,虽然MCD算法是从
divergence [7] 推导而出的,但是MCD算法与
divergence存在明显gap:MCD算法采用L_1 loss 衡量了classifiers ouputs 在element-wise的差异,而
divergence 只考虑了classifiers 类别预测的不一致性。考虑到MCSD是基于对classifiers ouputs 在element-wise的差异的度量,因此MCSD divergence 可以更直接,更紧密的支撑MCD这类基于classifiers outputs 差异做对抗训练的方法。
类似
,我们也可以基于
和
得到对应的类似McDalNet的算法。其中基于
得到的方法完全等效于DANN [1], 基于
得到的方法是MDD [6] 的一个变种。
我们将不同McDalNet 的算法在标准的域适应数据集上进行对比,结果如下图所示:
除了上述的McDalNet框架,基于MCSD divergence, 我们还引入了一个Domain-Symmetric Networks (SymmNets)的新框架,如下图所示。
该框架是基于CVPR 的论文[3]做的改进,因此我们称之为SymmNets-V2. 相对于McDalNets, SymmNets-V2 没有额外的task classifier,而是将其与classifiers for 域对齐进行了合并。该方法在网络结构上的鲜明特点是将两个classifiers拼接到一起,并用拼接得到的classifier用作域对齐;通过这种方式,我们赋予了两个classifiers 明确的domain 信息,同时取得了更优的实验结果。SymmNets-V2 的优化目标如下:
其中
是分类损失,用来赋予
类别信息,
用来增大
的输出差异,
和
分别用来减小
在源域数据和目标域数据上的输出差异。其具体定义和与MCSD的联系请参考论文。
对于熟悉DANN [1] 的读者,可以将SymmNets看做将category information 引入DANN的直接扩展。具体来说,如果我们分别将
中的所有类别当成整体,那么整体化之后的
就分别对应着DANN 二分类domain classifier 中的源域和目标域;这样SymmNets中的增大/减小
的输出差异就对应着DANN中的domain discrimination/domain confusion. 将DANN 二分类domain classifier 中的源域和目标域扩展成由
拼接成的2K 分类器,可以为在域对齐过程中引入category information做好模型结构准备。
实践方面:我们在closed set, partial, and open set domain adaptation三个任务共七个数据集上验证了我们提出的McDalNets和SymmNets的有效性。相对closed set 的任务,partial and open set domain adaptation任务中的难度增大很大程度是两个域中共享类别的样本与其中一个域中独有类别的样本在adaptation 过程中发生了错误对齐带来的;因此SymmNets中对category information 的引入和对category level alignment 的促进可以极大的缓解该错误对齐现象,从而对partial 和open set domain adaptation带来帮助。最后,我们通过如下的t-SNE可视化来说明我们提出的SymmNets的有效性。
[1] Domain-Adversarial Training of Neural Networks, JMLR16
[2] Maximum Classifier Discrepancy for Unsupervised Domain Adaptation, CVPR18
[3] Domain-Symmetric Networks for Adversarial Domain Adaptation, CVPR19
[4] Unsupervised Domain Adaptation via Regularized Conditional Alignment, ICCV19
[5] Sliced wasserstein discrepancy for unsupervised domain adaptation, CVPR19
[6] Bridging Theory and Algorithm for Domain Adaptation, ICML19
[7] A theory of learning from different domains,ML10