Prototypical Networks for Few-shot Learning 论文笔记
前言
本文提出了用于few-shot learning的原型网络(prototypical network),它的基本思想是,在一个embedding空间中,每个类都有一个原型表示(prototypical representation),每个类的点都聚集在这个原型表示周围。具体来说就是,通过一个非线性映射将输入映射到一个embedding空间中,提取每个类别中样本的均值(mean)作为该类在embedding空间中的原型(prototype),那么对于一个embedded point来说,分类就是找到离它最近的类原型。使用属于Bregman divergence的欧几里得距离作为度量距离,在训练时,让测试样本到自己所属的类的原型的距离越近越好,到其它的类原型越远越好;测试时,对测试样本到各个类原型的距离做softmax,预测测试样本的类别标签。
原型网络
在few-shot分类中,给定一个支持集SS,其中有NN个带标签的样本,即S={(x1,y1),...,(xN,yN)}S=\lbrace(x_1,y_1),...,(x_N,y_N) \rbrace,每个xi∈RDx_i \in \Bbb R^D是一个样本的DD维的特征向量,yi∈{1,...,K}y_i \in \lbrace 1,...,K \rbrace是其对应的类别标签,SkS_k表示类别为kk的支持集。
原型网络通过一个嵌入函数fϕ:RD→RMf_{\phi}: \Bbb R^D \to \Bbb R^M为每个类计算一个MM维的原型ck∈RMc_k \in \Bbb R^M,即将维度为DD的数据映射到MM维的embedding空间中,每个类的原型ckc_k是支持集中样本的embedding的均值向量:

给定一个距离函数d:RM×RM→[0,+∞)d: \Bbb R^M \times \Bbb R^M \to [0,+ \infty)计算query point xx和各个类原型之间的距离,原型网络通过将softmax作用在这些距离上,得到xx的类别概率分布:

训练过程就是通过SGD最小化目标函数J(ϕ)=−logpϕ(y=k∣x)J(\phi)=-log p_{\phi}(y=k|x),其中kk是训练样本的真实标签。训练时的episode为:从训练集中随机选择一个类子集,在每个类中选择一些样本作为支持集,从剩余的样本中再选出一些作为query point,伪代码如下:

在这个算法中,NN是训练集中样本的数量,KK是训练集中类别的数量,NC≤KN_C \leq K是每个episode中被选择的类别的数量,NSN_S是每个类的支持集中的样本数量,NQN_Q是每个类的query样本的数量,RANDOMSAMPLE(S,N)R_{ANDOM}S_{AMPLE}(S,N)表示从SS中随机选出NN个元素。
- 输入是训练集DD,DkD_k表示该数据集中所有样本的类别都为kk
- 最后要得到的是一个训练episode的损失JJ
- 首先从KK个类中随机选出NCN_C个类
- 然后对于NCN_C个类中的每一个类,选择NSN_S个样本作为支持集,再从剩下的样本中选出NQN_Q个作为query样本,最后计算每个支持集的类原型
- 将损失JJ初始化为0,对于每个query样本计算损失JJ为:

关于原型网络的两种解释
1. 将原型网络看作混合密度估计
原型网络中对于距离的度量属于Bregman散度,其中就包括平方欧氏距离和马氏距离,本文使用的就是平方欧氏距离。本文对于原型的计算可以看作是在支持集上的hard clustering,每个类都有一个簇,支持集中的每个样本都被分配到其对应的类簇。对于Bregman散度来说,如果一个cluster representative到该聚类中所有点的距离最小,那么它就是该聚类的平均(mean),因此原型的计算就是使用Bregman散度在支持集中计算出最佳的cluster representative,即找到最优的聚类中心。
2. 将原型网络解释为线性模型
当使用欧氏距离时,原型网络模型中的softmax相当于有特定参数的线性模型,代入d(z,z′)=∣∣z−z′∣∣2d(z,z{'})=||z-z{'}||^2并展开得到:

第一项对类别kk来说是常量,不会对softmax概率造成影响,将后两项写成线性模型为:

本文的结果证明,尽管可以等同于线性模型,但欧氏距离依然是一个有效的选择,作者猜想这是由于在嵌入函数中已经学到了所需的非线性元素,因此使用欧氏距离使得方法更加简单有效。
与Matching Networks的比较
原型网络和MatchingNet在few-shot方面是不同的,但两者在one-shot上是相同的 。MatchingNet在给定支持集的情况下产生了一个加权最近邻分类器,而原型网络使用平方欧氏距离产生了一个线性分类器。在MatchingNet中,提出了一系列对网络结构的扩展,包括解耦支持样本和query样本的嵌入函数、使用FCE等,这些或许可以和原型网络结合起来,但会增加参数,并且FCE使用双向LSTM对支持集进行任意排序,这些都使网络结构变得更复杂。本文证明使用简单的设计也能达到同样的性能。
在one-shot方面 ,当每个类的支持集中只有一个样本时,ck=xkc_k=x_k,即原型就是那个样本,此时原型网络和MatchingNet是可以等同的。
还有一个问题就是,在每个类的支持集中设置多个原型会不会更有用 。如果每个类的原型的个数是固定的并且大于1,那么需要一种划分机制来进一步对属于一个类的样本点进行聚类,这就需要另一个单独的划分阶段,这个阶段和权重的更新是分开的,而本文的方法,即只设置一个原型,只需要用简单的梯度下降方法去学习。
一些其他的设计思路
1. 距离度量
MatchingNet中使用的是余弦距离,虽然说原型网络和MatchingNet使用任何距离度量都是可行的,但作者发现,使用平方欧氏距离能使两种网络都得到提高。作者认为这是由于余弦距离并不属于Bregman散度,因此网络也并不能等同于混合密度估计。
2. episode的设置
以往的实验发现,在训练和测试是保持相同的episode设置往往会获得较好的结果,比如在测试时使用5-way 1-shot的方式,那么训练时的episode就应该设置为Nc=5N_c=5,NS=1N_S=1,其中NcN_c是从数据集中选择的类别的个数,NSN_S是每个类中被选为支持集样本的个数。然而本文发现,使用比测试时更高的NcN_c(way)对模型更有益,在实验中,根据保留的验证集对NcN_c进行调整。不过对于NSN_S来说,在训练和测试时shot的设置最好保持一致。
原型网络对于zero-shot所做的一些改动
zero-shot和few-shot还是有些不同的,few-shot为每个类生成一个支持集,而在zero-shot中,为每个类设置一个元数据向量vkv_k,vkv_k是根据每个类的属性描述、原始数据等生成的,也即每个类的原型是一个事先知道的元数据向量vkv_k,也可以通过学习来得到。也可以简单的修改原型网络使其能够处理zero-shot learning:简单的定义ck=gv(vk)c_k=g_v(v_k),为元数据向量的embedding。由于元数据向量和query point来自不同的输入域,作者发现可以固定原型embedding的长度为单位长度,而对query embedding不设限制。
