Advertisement

prototypical networks for few-shot learning论文

阅读量:

文章目录

  • 摘要
  • 引言
  • Prototypical network
    • Notaion
    • model
    • zero-shot Learning

Experimental studies have been conducted on Omniglot low-shot classification, as well as miniImageNet low-shot classification and CUB zero shot classification tasks.

  • 参考

论文:基于 prototypical 网络的少样本学习架构
地址:https://arxiv.org/abs/1703.05175v2
代码:https://github.com/jakesnell/prototypical-networks

摘要

针对小样本分类问题, 作者提出了一种基于原型的设计方案. 该模型作为分类器, 在训练集中未曾见过的新类别仅依赖少量数据即可实现良好识别. 原型网络通过构建度量空间来实现分类任务, 其核心机制在于计算输入样例与梅格雷(梅格勒) prototype表示之间的距离, 进而完成类别判别. 相较于当时盛行的小样本学习方法, 原型网络展现出一种更为简洁的归纳偏差特征, 在数据资源有限的情况下往往能够获得令人满意的性能表现. 论文中深入分析表明, 一些看似微不足道的设计决策, 实际上能在涉及复杂结构选择以及元学习(meta-learning)等方面带来实质性的性能提升. 进一步研究将该模型扩展至零样本学习场景, 并在CUB数据集上实现了最优性能.

引言

论文中探讨并解决了小样本学习中的过拟合问题。该研究提出了一种称为"原型网络"的方法,在支持集(support set)中通过神经网络将输入数据映射至一个度量空间,并用类原型(c_k)来表示每一种类别。在分类过程中,则是将待分类的数据点同样映射至该度量空间得到结果x后,并计算其与类原型c_k之间的距离,在此基础之上判断其归属关系(如图所示)。

在这里插入图片描述

Prototypical network

Notaion

支持集中共包含N个带标签的数据样本。数学表达式如下所示:S = \{(\mathbf{x}_1,y_1),\dots,(\mathbf{x}_N,y_N)\}其中每个特征向量\mathbf{x}_i\in\mathbb{R}^D属于D维空间。每个样本的类别标记由变量 y_i ∈ {1,…,K} 表示,并取值于集合 {1,…,K}`中。\对于类别k来说,子集 S_k = { (\mathbf{x}, y) | y=k } $ 包含了所有标记为k的带标签数据样本

model

计算类典型实例c_k属于实数空间\mathbb{R}^M
嵌入函数定义为f_\phi: \mathbb{R}^D \rightarrow \mathbb{R}^M
其中所采用的是神经网络模型。

在这里插入图片描述

用softmax分类:

在这里插入图片描述

优化是用过SGD最小化J(\phi) = -log p_\phi(y=k|\mathbf{x})

在这里插入图片描述

算法主要由两部分构成:首先,在episode中的每一类中计算对应的原型c_k是通过将该类所有数据向量化后取平均值得到的;其次,在query set的基础上进一步优化了整个算法流程。

zero-shot Learning

在零样本学习中与少样本学习相异的是其对应的\mathbf{v}_k并非基于训练集中的支持样例生成反而是通过综合考虑各类属性特性和原始数据等因素来生成这一过程提供了更加灵活的信息获取机制这些信息都可以通过分析各类属性特性和原始数据来进行提取和确定同样能够灵活地转换为适应零样本学习的情境为了便于构建我们可以定义\mathbf{c}_k = g_\theta \mathbf{v}(k)作为一个对应的meta-data向量

Experiments

Omniglot Few-shot Classification

Omniglot数据集包括了50种类型的手写字符共计1,623类。每一类字符仅提供20个样本,并且每个样本均由不同的人绘制而成。其分辨率设置为1\times 1像素吗?抱歉,在这里可能有些混淆,请重新确认参数设置的具体数值以确保图像质量达到最佳效果?或者更准确地说,在这种情况下图像的分辨率被设定为1\times 1像素?不过根据之前的描述似乎不太合理,请再审阅一下参数设定的具体内容以确保一致性与准确性?另外关于 episode 的数量与分类数目之间的关系也需要进一步明确以避免误解

在这里插入图片描述

miniImageNet Few-shot Classification

MinilmageNet数据集涵盖100个类别,每个类别包含600个样本数据.其中64个类别用于构建训练集的数据集合,16个类别用于构建验证集合的数据集合以及2O个类别用于构建测试集合的数据集合.该文分别采用3O-way的episode对I-shot分类任务以及2O-way的episode对5-shot分类任务来训练模型参数.在模型训练与测试过程中确保每次查询中的shot数目相同,并且每个查询点选取了各分类器下各I5个特征向量作为参考点

在这里插入图片描述

该实验分别采用了1-shot与5-shot两种配置,在5-way与20-way两种分类方式下进行了 episode 训练。实验结果表明,在 training episode 中采用更高类别的分类方式会带来一定程度上的提升效果。这一现象的根本原因在于:通过提高way的数量能够显著提升模型在复杂场景下的适应能力,并促使模型在嵌入空间实现更加细致的决策划分。

在这里插入图片描述

CUB Zero-shot Classification

CUB数据集划分为训练集、验证集和测试集三个部分。其中训练集包含100个互不重叠的任务类别(task category),验证集和测试集各包含50个互不重叠的任务类别。在312维的空间中(spatial domain),模型通过编码器模块提取鸟类的各种特征信息(feature information),包括其种类(category)、颜色(color)以及羽毛(feather)等基本属性信息(basic attribute information)。在每个训练周期(training episode)中涉及50个不同的鸟类分类任务(classification task),每个分类任务均分配10个不同的查询样本点(query points)。

在这里插入图片描述

参考

本节主要探讨—— Prototypical Networks 在小样本学习中的应用

本节主要探讨—— Prototypical Networks 在小样本学习中的应用

全部评论 (0)

还没有任何评论哟~