机器学习 by 李宏毅(11)

Domain Adaption

训练一个书写数字识别模型不是难题,但是如果 Training Data 是黑白图像,Testing Data 是彩色图像上的正确率非常低。

Domain Shift:训练数据和测试数据具有不同的 distribution 。

1

Domain Adaptation 就是解决 Domain Shift 的问题导致的model性能变差,也可以看做 Transfer Learning 的一部分

Domain Shift

  • 模型输入数据的变化只是 Domain Shift 的一种类型,而输出的数据分布也可能有变化,也就是输出某个数字的几率特别大
  • 输出和输出的关系也可能发生变化

2

我们定义 Training Data 来自 Source Domain,Testing data 来自 Target Domain

3

Domain Adaptation

Training Data 是来自 Source Domain labeled data,希望 Training data 训练得到的模型可以用在不同的Domain 上,所以在训练时必须对 Target Domain 有一定的了解

  • 有少量来自 Target Domain labeled data,可以用这些数据 fine-tune trained model。由于数据很少,需要小心 model 的 overfit
  • 有大量来自 Target Domain unlabeled data,基本的想法是训练一个特征抽取模型,可以无视不同 Domain 的差异,抽取出一致的 feature distribution
  • 有少量来自 Target Domain unlabeled data
  • 对Target Domain 一无所知

4

如何找出 Feature Extractor Network?

可以把一般的 Classifier 分为 feature extractor 和 label predictor 两部分

procedure:

  1. Source Domain Data 直接输出到 model 里按照一般的方法进行训练
  2. Target Domain Data 没有任何标注,所以输入model,把feature extractor 的 output 拿出来,要与 Source Domain 的 feature extractor output 没有差异,图中分别同蓝色点和红色点表示

5

Domain Adversarial Training

需要训练一个 Domain Classifier,输入是 feature extractor output,可以判断 feature 来自 Source Domain 还是 Target Domain。

而 Feature Extractor 需要 Fool Domain Classifier,本质上是 GAN

6

但是 Feature Extractor 完全可以无视输入,只输出 0 就可以骗过 Domain Classifier,所以需要 Label Predictor 阻止这种情况的发生

\(\theta_f\) 表示 Feature Extractor,\(\theta_p\) 表示 Label Predictor,\(\theta_d\) 表示Domain Classifier,L 表示 Source Domain Data 的 loss,\(L_d\) 表示 Domain Classifier 的 binary Loss \[ \theta_p^*=\underset{\theta_p}{min\ }L \\ \theta_d^* = \underset{\theta_d}{min\ }L_d \\ \theta_f^* = \underset{\theta_f}{min\ } L-L_d \] 实验结果如图

7

但是仅仅 \(-L_d\) 未必是最好的做法,目的是要 \(L_d\) 变大,也就是 Feature Extractor 欺骗 Domain Classifier 无法区分Source 和 Target,但是如果最大化 \(L_d\),有可能的结果是 Feature Extractor output 使得 Source 被分类为 Target,Target 被分类为 Source,但这不符合我们的目标,我们需要让全部 data 分为同一类别

Limitation

  • class 1 和 class 2 用不同的图形表示,Target 没有 class label 用一个图形表示,目标是使得分布越接近越好,但是接近的方式可以有很多种

8

显然,右边的接近方式更好,更易于分类,所以要让 Target 远离分界。简单的方式就是让分类的结果更集中,分类的置信度越高越好

9

  • 不同 Domain 里的类别可能并不一样,因为 Target Domain 没有label

10

Domain Generalization

对 Target Domain 一无所知

  • 假设 Training Data 来自各种不同的 Domain,期待 model 可以学到如何消除Domain带来的影响

11

  • 假设 Training data 只有一个 domain,Testing data 来自多个 domain,可以通过 Training data 生成 多个 domain 的 data

12