量化百科

《小王爱迁移》系列之九:开放集迁移学习(Open Set)

由iquant创建,最终由iquant 被浏览 4 用户

这篇文章是刚刚开完的ICCV 17的一篇,文章的setting是特别新颖的,还获得了今年的ICCV Marr Prize 提名。所以很值得一看。

关于代码以及其他

作者并没有提供这个文章的代码,因为他说是在公司开发的,要拿到公司的许可。其实文章的方法本身并不难,复现难就难在:有大量的细节被省略了。比如,变换矩阵 W 到底是怎么求解的?怎么在有未知类的情况下训练SVM?测试数据怎么分的,等等。我根据自己的理解,对文章中一些关键步骤用Matlab进行了实现,地址在这里。期待作者公开代码。

Motivation


现有的domain adaptation都针对的是一个“封闭”的任务(close set),就是说,source和target中的类别是完全一样的:source有几类,target就有几类。作者在这里说,这些方法都只是理想状态下的domain adaptation。而真正的环境中,source和target往往只会共享一些类的信息,而不是全部。现有的方法都不能很好地解决这个问题。简而言之:

  • 原来的:source有5类,target也是这完全相同的5类,叫做closed set;

  • 本文:source有5类,target只共享了其中某些类,还有未知类。我自己画个图:

![](data:image/svg+xml;utf8,<svg%20xmlns='http://www.w3.org/2000/svg' width='566' height='273'></svg>)

Method


文章的方法比较简单,容易看懂。整个文章的解决思路大致是这样的:

  • 利用source和target的关系,给target的样本打上标签
  • 并将source转换到和target同一个空间中

两者依次迭代,直到收敛。作者根据target domain是否有label,把问题分成了unsupervised和semi-supervised domain adaptation。然后分开解决。空间变换这一步是共同的。

Unsupervised domain adaptation

作者用 x_{ct} 来标识,target domain中的第 t 个样本是否被标记为类别 cx_{ct} in {0,1} 。同时,因为是个open set,所以,引入一个 o_t 来标识第 t 个样本是否为未知类别(outlier;unknown), o_t in {0,1} 。接下来,把这个问题直接表示成了一个二值的整数规划问题:

![](data:image/svg+xml;utf8,<svg%20xmlns='http://www.w3.org/2000/svg' width='524' height='292'></svg>)

其中有两个约束:

  • 第一个约束是说,对于任意target中的一个样本,要么 x_{ct}=1 ,要么 o_t=1 。也就是,这个样本要么是属于已知类别 c ,要么是未知类别。
  • 第二个约束是说,对于任意一个已知类别 c ,至少要有一个 x_{ct} =1 。也就是,每个类别下都至少要有一个sample。不会存在有一个已知的共享类别,在target下找不到。

这里的 d_{ct}=||S_c-T_t||^2_2 是一个距离,表示source中属于类别cc的所有sample的均值与target中第 t 个sample的差异。

这两个约束都非常好理解。这就是这个问题的全部形式。解决方法可以按照作者文章说的调用SCIP包就行。

Semi-supervised domain adaptation


要处理semi-supervised情况,只需要在现有的unsupervised情况下,添加那些有label的target的约束信息。作者为了达到这个目的,引入了一个新的变量 x_{hat{c}_tt}=1 。这里的 hat{c}_t 是有label这部分的target label。作者把原来的优化目标改造成了

![](data:image/svg+xml;utf8,<svg%20xmlns='http://www.w3.org/2000/svg' width='539' height='89'></svg>)

其中 d_{cc'}=||S_c-S_{c'}||^2_2 。作者没有说这是什么意思,推测 S_c' 应该是有label的那部分target类别 c 的中心。

这个问题是一个二次规划问题,作者又运用数学知识,改造成了线性问题。详情请看文章。

学习source到target的映射

作者在这里要学习的目标是一个变换矩阵 W ,通过 W 可以把source变换到target空间中。学习的目标函数是:


f(W)=frac{1}{2}sum_{t}^{} sum_{c}x_{ct}||WS_c-T_t||^2_2通过求解偏导数,使其最小,可以求解出 W 。这个部分比较简单。

总的学习过程就是,学习label和学习映射进行交替,直到收敛或者目标值小于某一值即可。作者在文章中说,迭代次数一般少于10次,通常为3~5次即可。

实验

实验部分,作者主体部分是利用office和caltech数据集构造了一个protocol:把这两个数据集中的10个公共类作为两个domain共享的类,然后把office中的31个类剩余21个类作为未知类,分开放入source和target中。为了确保在source和target中的未知类不同。


作者首先对比了现有的一些深度方法,比如DAN、RTN、BP,然后发现提出的方法不仅在open set,在close set上也很好。然后,提取深度特征后,又对比了TCA、GFK、SA、CORAL这几个方法,仍然是作者的方法好。文章做了大量的实验,解释了很多open set下进行domain adaptation的规律。详细请参考文章。

总结


这篇文章提出了一个新颖的问题场景,是现有的domain adaptation方法没有解决过的。作者描述了这个场景的重要性,并且,提出了自己的并不那么复杂的方法来解决。文章的贡献是值得肯定的。可以预见的未来肯定会有一大波工作是延续这个文章进行的。比如,修改现有的方法使这适应这个open set?作者给出了代码地址:Heliot7/open-set-da

Reference

文章原文:Pau Panareda Busto and Juergen Gall. Open Set Domain Adaptation. ICCV 2017.

\