
[Paper Notes] Domain-Adaptive Few-Shot Learning


Paper link: https://openaccess.thecvf.com/content/WACV2021/papers/Zhao_Domain-Adaptive_Few-Shot_Learning_WACV_2021_paper.pdf

GitHub link: [GitHub - dingmyu/DAPN: A pytorch implementation of "Domain-Adaptive Few-Shot Learning"](https://github.com/dingmyu/DAPN "GitHub - dingmyu/DAPN: A pytorch implementation of "Domain-Adaptive Few-Shot Learning"") (the repository has bugs and missing files and cannot be run)


Methodology (figure from the original paper)

Problem Definition

Given:

- a large sample set D_s from a set of source classes C_s in the source domain,

- a few-shot sample set D_d from a set of target classes C_d in the target domain,

- a test set T from another set of target classes C_t in the target domain,

where 1) $C_s \cap C_d = \emptyset$, $C_d \cap C_t = \emptyset$, $C_s \cap C_t = \emptyset$, and 2) the data distributions of D_s and D_d are also different.

Objective: train a model with D_s and D_d and then evaluate its generalization ability on T.


Episode Training

To form a training episode e_s:

1. randomly choose N_{sc} classes from D_s

2. build two sets of samples from the N_{sc} classes: the support set S_s consists of $k \times N_{sc}$ samples (k samples per class), and the query set Q_s is composed of samples from the same N_{sc} classes.

Training episodes are also built from the few-shot sample set D_d (a data augmentation method is needed since D_d contains only a few samples per class); a minimal sampling sketch is given after the list below.

To form a training episode e_d:

1. randomly choose N_{dc} classes from D_d

2. build two sets of samples from the N_{dc} classes: the support set S_d consists of $k \times N_{dc}$ samples (k samples per class), and the query set Q_d is composed of samples from the same N_{dc} classes.
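A minimal sketch of how such an N-way k-shot episode could be sampled, assuming the data is organized as a dict from class label to a list of samples; the function name and the n_query parameter are illustrative, not taken from the paper or the repository:

```python
import random

def sample_episode(data_by_class, n_way, k_shot, n_query):
    """Sample one N-way k-shot episode: a support set and a query set.

    data_by_class: dict mapping class label -> list of samples
                   (each class must hold at least k_shot + n_query samples).
    Returns (support, query) as lists of (sample, episode_label) pairs.
    """
    classes = random.sample(list(data_by_class.keys()), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw k_shot + n_query distinct samples from this class.
        samples = random.sample(data_by_class[cls], k_shot + n_query)
        support += [(x, episode_label) for x in samples[:k_shot]]
        query += [(x, episode_label) for x in samples[k_shot:]]
    return support, query

# Example: a 5-way 1-shot source episode e_s with 15 query samples per class.
# support_s, query_s = sample_episode(D_s_by_class, n_way=5, k_shot=1, n_query=15)
```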


Feature Extractor

The feature extractor in the paper's code is a ResNet-18.
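A minimal sketch of such a backbone, using torchvision's ResNet-18 with the classification head removed; this is an assumption about the setup, not the repository's exact code:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

class FeatureExtractor(nn.Module):
    """ResNet-18 backbone that outputs a flat feature vector per image."""
    def __init__(self):
        super().__init__()
        backbone = resnet18()  # randomly initialized ResNet-18
        # Drop the final fully connected layer; keep everything up to global pooling.
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        feats = self.encoder(x)   # (B, 512, 1, 1)
        return feats.flatten(1)   # (B, 512)

# Example: 512-d embeddings for a batch of 84x84 images.
# f = FeatureExtractor(); z = f(torch.randn(4, 3, 84, 84))  # z.shape == (4, 512)
```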


Prototypical Learning

The basic idea is based on the Prototypical Network: learn a prototype of each class in the support set S_s, and classify each sample in the query set Q_s based on the distances between the sample and the different prototypes.

The distance between a sample embedding $f_\varphi(x)$ and a class prototype $p_c^s$ can be computed with the following formula (dist is the Euclidean distance):
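The formula image is not reproduced in this post. In the standard Prototypical Network formulation, which this module follows, the prototype $p_c^s$ is the mean support embedding of class $c$, and each query sample is classified by a softmax over negative distances:

$$
p_c^s = \frac{1}{k} \sum_{(x_i, y_i) \in S_s,\, y_i = c} f_\varphi(x_i),
\qquad
p(y = c \mid x) = \frac{\exp\big(-\mathrm{dist}(f_\varphi(x), p_c^s)\big)}{\sum_{c'} \exp\big(-\mathrm{dist}(f_\varphi(x), p_{c'}^s)\big)}
$$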

The loss function over each episode e_s is the negative log-probability of each query sample:
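The formula image is missing here; a standard form of this loss, assuming the paper keeps the original Prototypical Network objective, is

$$
J(e_s) = -\frac{1}{|Q_s|} \sum_{(x, y) \in Q_s} \log p\big(y \mid x\big)
$$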

Similarly, the loss function over each episode e_d is the negative log-probability of each query sample:
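A minimal PyTorch sketch of this prototypical loss for one episode; the function name and tensor layout are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def prototypical_loss(support_emb, support_labels, query_emb, query_labels, n_way):
    """Negative log-probability loss of a single episode.

    support_emb: (n_way * k, d) embeddings of the support set
    support_labels: (n_way * k,) episode labels in [0, n_way)
    query_emb: (n_query_total, d) embeddings of the query set
    query_labels: (n_query_total,) episode labels in [0, n_way)
    """
    # Class prototypes: mean embedding of each class's support samples.
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(n_way)]
    )                                                # (n_way, d)
    # Squared Euclidean distance from every query sample to every prototype.
    dists = torch.cdist(query_emb, prototypes) ** 2  # (n_query_total, n_way)
    # Softmax over negative distances, then negative log-likelihood.
    log_p = F.log_softmax(-dists, dim=1)
    return F.nll_loss(log_p, query_labels)
```

The loss for an episode e_d from D_d is computed in exactly the same way, with the target-domain support and query embeddings.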


Domain Adversarial Adaptation Module

1. Domain Adaptive Embedding

The embedding module consists of an autoencoder and an attention submodule.

Purpose: make the resulting features as domain-confused as possible.
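A rough sketch of what such an embedding module could look like: a small linear autoencoder followed by a channel-wise attention gate. The layer sizes, names, and exact wiring are my assumptions for illustration, not the paper's architecture:

```python
import torch
import torch.nn as nn

class DomainAdaptiveEmbedding(nn.Module):
    """Autoencoder + attention submodule over backbone features (illustrative)."""
    def __init__(self, feat_dim=512, hidden_dim=256):
        super().__init__()
        # Autoencoder: compress and reconstruct the backbone feature.
        self.encoder = nn.Sequential(nn.Linear(feat_dim, hidden_dim), nn.ReLU())
        self.decoder = nn.Linear(hidden_dim, feat_dim)
        # Attention submodule: channel-wise gate on the reconstructed feature.
        self.attention = nn.Sequential(nn.Linear(feat_dim, feat_dim), nn.Sigmoid())

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        attended = recon * self.attention(recon)  # attention-weighted embedding
        return attended, recon                    # features after / before attention
```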

2. Domain Adaptive Loss

Following the Conditional Domain Adversarial Network (CDAN) paper, an additional domain discriminator D is introduced to handle the domain adversarial loss between the source distribution P_s and the target distribution P_t:
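The formula image is not reproduced here. In CDAN the discriminator sees features conditioned on the classifier prediction; writing $h(x)$ for whatever conditioned feature is fed to $D$, the generic adversarial term has the usual binary cross-entropy form:

$$
\mathcal{L}_{adv} = -\,\mathbb{E}_{x \sim P_s}\big[\log D(h(x))\big] - \mathbb{E}_{x \sim P_t}\big[\log\big(1 - D(h(x))\big)\big]
$$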


Domain Discriminative Loss

The features before the self-attention embedding layer are made domain-distinguishable, while the features after it are made domain-confused:
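One plausible reading, with a binary domain label $d \in \{0, 1\}$ and a domain classifier $D'$ introduced here only for illustration: the pre-embedding features $f$ receive an ordinary domain-classification loss, while the post-embedding features $\tilde f$ are pushed towards a uniform (confused) domain prediction. This is a sketch, not the paper's exact formulation:

$$
\mathcal{L}_{dis} = -\,\mathbb{E}\big[d \log D'(f) + (1 - d)\log(1 - D'(f))\big],
\qquad
\mathcal{L}_{conf} = -\,\mathbb{E}\big[\tfrac{1}{2}\log D'(\tilde f) + \tfrac{1}{2}\log(1 - D'(\tilde f))\big]
$$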


Adaptive Reweighting Module

An adaptive reweighting module combines the four losses into the final training objective.
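A minimal sketch of one way to combine four losses with adaptive weights, using softmax-normalized learnable weights, one per loss; the class name and scheme are illustrative and not necessarily the paper's reweighting rule:

```python
import torch
import torch.nn as nn

class AdaptiveReweighting(nn.Module):
    """Combine several losses with learnable, softmax-normalized weights (illustrative)."""
    def __init__(self, n_losses=4):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(n_losses))  # one learnable logit per loss

    def forward(self, losses):
        # losses: sequence of scalar loss tensors, e.g. [L_proto_s, L_proto_d, L_adv, L_dis]
        weights = torch.softmax(self.logits, dim=0)
        return sum(w * l for w, l in zip(weights, losses))

# Example: total = AdaptiveReweighting()([loss_s, loss_d, loss_adv, loss_dis])
```

Note that naively minimizing the combined objective over such weights would favor the smallest loss, so a practical version needs to constrain or regularize the weights; the paper's actual module may differ from this sketch.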


Experiments

[Omitted] Experiments are conducted on miniImageNet, tieredImageNet, and DomainNet.
