天天看点

【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation

@[TOC](【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation))

论文地址:http://cn.arxiv.org/pdf/1702.08400v3

代码地址:https://github.com/corenel/pytorch-atda#pytorch-atda

基本介绍

ATDA 解决的问题: 源数据有类别标签,目标域没有标签;如何去学习到目标域的特征表达?在不同域都具有良好的效果的分类器是不存在的,文中的解决方法如下:用两个网络来去标记目标域的标签,和另外一个网络学习目标域的特征表达。

主要结构

  1. 主要流程如下:
    1. 首先用源域数据(有标签)训练两个分类器。
    2. 用源域训练好的分类器给目标域的数据打标签,只有两个分类器的预测标签是一样的,并且至少有一个是大于给定阈值的,该标签才是可靠的。
    3. 用目标域的数据(有伪标签)去训练一个新的分类器。
      【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation
  2. 流程图
    【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation

这张图就是上面的基本流程的具体化,主要有四个部分:F(源域和目标域共享的网络),F1,F2,Ft。 整个网络具体训练的过程如下:

【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation
  1. 训练过程:
    1. 首先使用source data 训练F, F1, F2, Ft
    2. N 代表重采样目标域的样本个数,用第一步训练好的网络预测目标域的伪标签
    3. 用源域数据 + 目标域数据(含有伪标签)一起训练F, F1, F2
    4. 用目标域数据训练Ft,不断重采样目标域的数据,回到步骤2重复训练
  2. 训练过程中需要注意的问题:
    1. 为了保证训练过程中F1, F2 两条支路训练的分类器不一样, 文章对F1, F2两支路的参数进行了约束。在损失函数中加入了 |W1TW2|, W1, W2是F1,F2全连接层的第一层参数。 最终的损失函数为:
      【论文笔记】Asymmetric Tri-training for Unsupervised Domain Adaptation
    2. 只有两支路预测的标签是相同且至少有一个大于阈值(实验设置为0.9或0.95),该标签才有效(那如果不满足这两个条件呢,这个标签无效? 这一块没太懂)。

继续阅读