@[TOC](【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation))
论文地址:http://cn.arxiv.org/pdf/1702.08400v3
代码地址:https://github.com/corenel/pytorch-atda#pytorch-atda
基本介绍
ATDA 解决的问题: 源数据有类别标签,目标域没有标签;如何去学习到目标域的特征表达?在不同域都具有良好的效果的分类器是不存在的,文中的解决方法如下:用两个网络来去标记目标域的标签,和另外一个网络学习目标域的特征表达。
主要结构
- 主要流程如下:
- 首先用源域数据(有标签)训练两个分类器。
- 用源域训练好的分类器给目标域的数据打标签,只有两个分类器的预测标签是一样的,并且至少有一个是大于给定阈值的,该标签才是可靠的。
- 用目标域的数据(有伪标签)去训练一个新的分类器。
- 流程图
这张图就是上面的基本流程的具体化,主要有四个部分:F(源域和目标域共享的网络),F1,F2,Ft。 整个网络具体训练的过程如下:
- 训练过程:
- 首先使用source data 训练F, F1, F2, Ft
- N 代表重采样目标域的样本个数,用第一步训练好的网络预测目标域的伪标签
- 用源域数据 + 目标域数据(含有伪标签)一起训练F, F1, F2
- 用目标域数据训练Ft,不断重采样目标域的数据,回到步骤2重复训练
- 训练过程中需要注意的问题:
- 为了保证训练过程中F1, F2 两条支路训练的分类器不一样, 文章对F1, F2两支路的参数进行了约束。在损失函数中加入了 |W1TW2|, W1, W2是F1,F2全连接层的第一层参数。 最终的损失函数为:
- 只有两支路预测的标签是相同且至少有一个大于阈值(实验设置为0.9或0.95),该标签才有效(那如果不满足这两个条件呢,这个标签无效? 这一块没太懂)。