小样本学习论文--Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
文章目录
-
-
- 一、前言
- 二、论文解读
-
- 1、概述
- 2 、主要内容
-
-
2.1 问题设定
-
2.2 与模型无关的元学习算法
-
三、代码解析
-
- 1、抽取数据
-
- 2、MAML
- 3、训练过程
-
一、前言
对于已经系统学习过深度学习的同学来说,其算法精度往往与数据量的多少密切相关。人类只需观察数次狗的形象就能掌握区分狗类的能力;相比之下,深度学习算法可能需要大量的样本才能达到类似的识别能力。然而,在许多领域中收集高质量的数据往往成本高昂。因此研究如何使机器能够仅需几次观察即可掌握特定物体分类任务就成为了当前小样本学习领域的研究热点。
二、论文解读
1、概述
我们注意一下文章的title,其中有三个关键字Model-Agnostic(与模型无关的)、Fast Adaptation(快速适应)、Deep Networks(深度网络)。这三个关键字告诉我们这篇文章提出的方法,(本人感觉更像是思想或者框架)可以应用在各种神经网络的模型,并且可以快速适应不同的任务。下面我们来一起了解这三个关键字在文章中具体含义。
先介绍一下一些小样本学习(Few Shot Learing)中的一个概念,这也是一开始困惑我的概念,即 N-way N-shot。N-way 的意思是N分类,N-shot是在学习的样本中,每个类只提供5个样本,比如说让你学习辨认一只猫,只有5张5的照片供你学习。这篇文章做了3个实验,分别是有监督的图片分类,一个强化学习的实验,还有一个回归的实验。这几个实验的代码放在了一起,有点复杂。在这篇博客中,只以有监督的图片分类为例来解析代码。
在5-way 5-shot的分类实验中,使用的数据集是miniImagenet,这个数据集中有100个类别的图片,每一个类别中有600张图片,大概是100类生活中常见的自然与生活中物品的集合。每张图片的大小是84x84的大小,被划分成了train(64)、test(20)、val(16)三个子集。
2 、主要内容
2.1 问题设定
小样本元学习的目标是训练一个网络,这个网络可以经过少量的迭代次数快速的适应到新的任务中。定义一个模型 f,使得对于输入的X,会产生a. 我们训练这个网络使得它可以适应不同的无限的任务。
f(x) = a
Task,在图片分类的这个实验中,可以被定义为下式,其中L()是损失函数,这个损失是指在测试集上的损失,会在下面详细叙述。q()是样本的分布。
T =\{ { L(x_1,a_1),q(x_1) }\}
在这个模型中,作者考虑了一个Tas的分布p(T),在k-shot的情境下,使用k个样本训练模型,让模型学习从p(T)中抽取的新t_i,这k个样本是从q_i中抽取的,然后产生t_i的L_{t_i}。在meta-learning的过程中,使用q_i中没有用过的新样本来测试。模型f通过在q_i新样本上的 test\quad error的变化来提升的。也就是说在在每个任务中测试样本上的error作为了meta-learning过程的train error。
2.2 与模型无关的元学习算法
可以说这个算法显得非常复杂,但一旦深入掌握,就不会像想象中那么难以捉摸。

然而,在实际应用中这一类卷积神经网络的表现可能会呈现出一定的策略差异性。我们来详细探讨一下该模型的算法流程。
然而,在实际应用中这一类卷积神经网络的表现可能会呈现出一定的策略差异性。我们来详细探讨一下该模型的算法流程。

其中符号p(T)代表的是任务分布,在实际应用中,并没有特意对这一部分进行特别设定,在模型训练过程中,在处理样本采样时就已经自然形成了相应的概率分布模式。具体来说,则是尚不清楚其背后的概率分布类型是什么。\alpha,\beta分别指的是在任务学习阶段采用梯度下降方法所使用的学习率参数以及在元学习阶段所采用的学习率参数。而\theta则表示在神经网络模型f中的权重参数集。
- 初始化参数,这个没什么好说的
- while:
- 抽取Task,就是形成可能由不同内别图片组成的数据集,在作者提供的代码中,设定一个抽取4个Task,作为meta-learning的一个batch。在5-way 5-shot的情境下,作者为一个task抽取了100张照片,也就是5x20,5个类别,每个类别20张图片。task之间的5个类别有可能由重复的类别,也有可能不一样,这个是随机的。
- 对于每个Task:
- ,采样数据,把数据分成两部分,在5-way 5-shot设定中,一个类别只能使用5个类别来学习,那么把这100张照片分成5x5的训练集,以及5x15的验证集.
- 计算使用训练集得到的Loss,在图片分类的实验中,使用的是交叉熵函数。
- 通过Loss来计算SGD
- 使用验证集在经过6,7步调整的权重下计算test error。6,7,8三个步骤在图片分类的实验中循环了5次。
- 使用4个Task中的test error(5次循环中的最后一次)的平均值作为meta-learning的损失函数,来进行SGD过程。
- end
通过一个图例来辅助讲解一下:

在单个任务中, 通过左侧训练集执行五次SGD的过程, 然后利用右侧测试集计算测试误差, 在元学习过程中, 将一批数据中四个子任务的所有测试误差取平均值作为损失函数进行优化. 这一过程结束后, 模型参数收敛至图中所示的位置P

那么,我们再使用这个模型或者测试这个模型的准确度怎么用呢?在博客的最前面,我们说把100类图片分成了3个子集,train中有64个类,用于上述的meta-learning。现在要将这个模型用在新的任务集具有16个类的test数据集上。仔细一想,训练好的模型并没有看见过test数据集中任何类啊。现在就是要说title中的Fast Adaptation的关键字了,在5-way 5-shot设定中,在测试的时候从test数据集中随机抽取5个类,每个类抽取N(>5)张照片,其中每个类抽取5张照片,用来微调模型中的参数,比如说在一个新任务下,把模型的参数调整至 \theta_{3}^*的位置,就是task做的事,即在新任务下只用5张照片来学习一下,用剩下的照片来计算精度。
至此模型结束。那么为什么说title中的其他两个关键字呢,Model-Agnostic(与模型无关的)是说,可以把task换成其他可以进行SGD过程的模型;Deep Networks(深度网络)可以适用于所有的深度学习模型。
三、代码解析
官方发布的代码集整合了三个实验项目,在此基础上,博主依据源代码逻辑对相关内容进行了重新编写。(仅涉及图片分类的部分)如有错误之处,请赐教。
1、抽取数据
def make_data_tensor(self,train):
if train:
folders = self.metatrain_character_folders
num_total_batches = 200000 # meta-learning过程有200000/4个batch
else:
folders = self.metaval_character_folders
num_total_batches = 600
all_filenames = []
print("生成文件")
# 从训练集中抽取5个类,每个类20个样本,这个过程重复200000次
for _ in range(num_total_batches):
sample_character_folders = random.sample(folders,self.num_classes)
random.shuffle(sample_character_folders)
labels_and_images = get_images(sample_character_folders,range(self.num_classes),nb_samples=self.num_sample_per_class,shuffle=False)
labels = [li[0] for li in labels_and_images]
filenames = [li[1] for li in labels_and_images]
all_filenames.extend(filenames)
print("生成文件结束!")
# 使用tensorflow的机制来读取抽取的照片
filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames),shuffle=False)
image_reader = tf.WholeFileReader()
_,image_file = image_reader.read(filename_queue)
image = tf.image.decode_jpeg(image_file, channels=3)
image.set_shape((self.img_size[0], self.img_size[1], 3))
image = tf.reshape(image, [self.dim_input])
image = tf.cast(image, tf.float32) / 255.0
num_preprocess_threads = 1
min_queue_examples = 256
examples_per_batch = self.num_classes * self.num_sample_per_class # 每个批次样本的数量 = 类的数量 x 每个类中样本的数量
batch_image_size = self.batch_size * examples_per_batch
images = tf.train.batch(
[image],
batch_size = batch_image_size,
num_threads=num_preprocess_threads,capacity=min_queue_examples + 3*batch_image_size)
all_image_batches,all_label_batches = [],[]
# batch_size = 4 4个task组成一个meta-learning的batch
for i in range(self.batch_size):
image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch]
label_batch = tf.convert_to_tensor(labels)
new_list, new_label_list = [], []
for k in range(self.num_sample_per_class):
class_idxs = tf.range(0,self.num_classes)
class_idxs = tf.random_shuffle(class_idxs)
true_idxs = class_idxs * self.num_sample_per_class + k
new_list.append(tf.gather(image_batch,true_idxs))
new_label_list.append(tf.gather(labels,true_idxs))
new_list = tf.concat(new_list,0)
new_label_list = tf.concat(new_label_list,0)
all_image_batches.append(new_list)
all_label_batches.append(new_label_list)
all_image_batches = tf.stack(all_image_batches)
all_label_batches = tf.stack(all_label_batches)
all_label_batches = tf.one_hot(all_label_batches,self.num_classes)
return all_image_batches,all_label_batches
2、MAML
def constract_model(self,input_tensors = None,prefix='metatrain'):
if input_tensors is None:
self.inputa = tf.placeholder(tf.float32) # 从batch_x中抽取的前一部分
self.inputb = tf.placeholder(tf.float32) # 从batch_x中抽取的后一部分
self.labela = tf.placeholder(tf.float32)
self.labelb = tf.placeholder(tf.float32)
else:
self.inputa = input_tensors['inputa']
self.inputb = input_tensors['inputb']
self.labela = input_tensors['labela']
self.labelb = input_tensors['labelb']
with tf.variable_scope('model', reuse=None) as training_scope:
if 'weights' in dir(self):
training_scope.reuse_variables()
weights = self.weights
else:
# Define the weights
self.weights = weights = self.constract_weights()
lossesa,outputas,lossesb,outputbs = [], [], [], []
accuraciesa,accuraciesb = [], []
num_updates= max(self.test_num_updates,self.num_updates)
outputbs = [[]] * self.num_updates
lossesb = [[]] * self.num_updates
accuraciesb = [[]] * self.num_updates
def task_metalearn(inp, reuse = True):
inputa,inputb,labela,labelb = inp
task_outputbs,task_lossesb = [] ,[]
task_accuraciesb = []
task_outputa = self.forward(inputa,weights,reuse=reuse) #前向传播
task_lossa = self.loss_func(task_outputa,labela) #计算损失
grads = tf.gradients(task_lossa,list(weights.values()))
gradinents = dict(zip(weights.keys(),grads))
fast_weights = dict(zip(weights.keys(),[weights[key] - self.update_lr * gradinents[key] for key in weights.keys()])) #更新一次权重
output = self.forward(inputb,fast_weights,reuse=True)
task_outputbs.append(output)
task_lossesb.append(self.loss_func(output,labelb))
for j in range(num_updates - 1): # num_updates = 5
loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
grads = tf.gradients(loss, list(fast_weights.values()))
gradients = dict(zip(fast_weights.keys(), grads))
fast_weights = dict(zip(fast_weights.keys(),
[fast_weights[key] - self.update_lr * gradients[key] for key in
fast_weights.keys()]))
print("根据样本更新权重!",j+1)
output = self.forward(inputb, fast_weights, reuse=True)
task_outputbs.append(output)
task_lossesb.append(self.loss_func(output, labelb))
print("计算验证集上的损失!",j+1)
task_output = [task_outputa,task_outputbs,task_lossa,task_lossesb]
task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1),
tf.argmax(labela, 1)) # 计算inputa的准确率
# 计算每一步更新中inputp的准确度
for j in range(num_updates):
task_accuraciesb.append(
tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), tf.argmax(labelb, 1)))
task_output.extend([task_accuracya, task_accuraciesb]) # 保存在结果中并返回
return task_output
if self.norm is not None:
unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)
out_dtype = [tf.float32,[tf.float32] * self.num_updates,tf.float32,[tf.float32] * num_updates,tf.float32,[tf.float32] * num_updates]
result = tf.map_fn(task_metalearn,elems=(self.inputa,self.inputb,self.labela,self.labelb),dtype=out_dtype,parallel_iterations=self.meta_batch_size)
outputas,outputbs,lossesa,lossesb,accuraciesa,accuraciesb = result
if 'train' in prefix:
self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(self.meta_batch_size)
self.total_loss2 = total_loss2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(self.meta_batch_size) for j in range(num_updates)]
self.outputas,self.outputbs = outputas,outputbs
self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(self.meta_batch_size)
self.total_accuracies2 = total_accuracies2 = [
tf.reduce_sum(accuraciesb[j]) / tf.to_float(self.meta_batch_size) for j in range(num_updates)]
self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1) # 这一步用于预训练
optimizer = tf.train.AdamOptimizer(self.meta_lr)
# 使用b部分中损失来计算梯度
self.gvs = gvs = optimizer.compute_gradients(self.total_loss2[self.num_updates - 1])
gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
self.metatrain_op = optimizer.apply_gradients(gvs)
else:
self.meta_total_loss1 = total_loss1 =tf.reduce_sum(lossesa) / tf.to_float(self.meta_batch_size)
self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(self.meta_batch_size) for j in range(num_updates)]
self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(
self.meta_batch_size)
self.metaval_total_accuracies2 = total_accuracies2 = [
tf.reduce_sum(accuraciesb[j]) / tf.to_float(self.meta_batch_size) for j in range(num_updates)]
# 简要保存结果
tf.summary.scalar(prefix + 'Pre-update loss', total_loss1)
tf.summary.scalar(prefix + 'Pre-update accuracy', total_accuracy1)
for j in range(num_updates):
tf.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), self.total_loss2[j])
tf.summary.scalar(prefix + 'Post-update accuracy, step ' + str(j + 1), total_accuracies2[j])
def constract_conv_weights(self):
weights = {}
dtype = tf.float32
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
k = 3
weights['conv1'] = tf.get_variable('conv1',[k,k,self.channels,self.dim_hidden],initializer=conv_initializer, dtype=dtype)
weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['conv2'] = tf.get_variable('conv2',[k,k,self.dim_hidden,self.dim_hidden],initializer=conv_initializer, dtype=dtype)
weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['conv3'] = tf.get_variable('conv3',[k,k,self.dim_hidden,self.dim_hidden],initializer=conv_initializer, dtype=dtype)
weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['conv4'] = tf.get_variable('conv4',[k,k,self.dim_hidden,self.dim_hidden],initializer=conv_initializer, dtype=dtype)
weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['w5'] = tf.get_variable('w5',[800,self.dim_output],initializer=fc_initializer) # TODO 800?
weights['b5'] = tf.Variable(tf.zeros([self.dim_output]),name='b5')
return weights
3、训练过程
def train(model,saver,sess,exp_string,data_gen,resume_itr=0):
SUMMARY_INTERVAL = 100
SAVE_INTERVAL = 1000
PRINT_INTERVAL = 1000
TEST_PRINT_INTERVAL = PRINT_INTERVAL
if log:
train_writer = tf.summary.FileWriter(logdir + '/' + exp_string,sess.graph)
print('Done initialization,starting training')
prelosses,postlosses = [],[]
pre_total_time = time.time()
for itr in range(0,metatrain_iterations):
feed_dict = {}
input_tensors = [model.metatrain_op]
if(itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0):
input_tensors.extend([model.summ_op,model.total_loss1,model.total_loss2[num_updates - 1],model.total_accuracy1, model.total_accuracies2[num_updates-1]])
pre_time = time.time()
result = sess.run(input_tensors,feed_dict)
pos_time = time.time()
print("当前运行{}代,本次用时:{}分,总用时:{}分".format((pos_time-pre_time) / 3600,(pos_time - pre_total_time) /3600))
if itr % SUMMARY_INTERVAL == 0:
prelosses.append(result[-2])
if log:
train_writer.add_summary(result[1],itr)
postlosses.append(result[-1])
if (itr!=0) and itr % PRINT_INTERVAL == 0:
print_str = 'Iteration ' + str(itr)
print_str += ':' + str(np.mean(prelosses)) + ',' + str(np.mean(postlosses))
print(print_str)
prelosses,postlosses = [],[]
if (itr != 0) and itr % SAVE_INTERVAL == 0:
saver.save(sess,logdir+"/"+exp_string+"/model"+str(itr))
if (itr!=0) and itr % TEST_PRINT_INTERVAL == 0:
feed_dict = {}
input_tensors = [model.metaval_total_accuracy1, model.metaval_total_accuracies2[num_updates-1], model.summ_op]
result = sess.run(input_tensors,feed_dict)
print("Validation result: " + str(result[0]) + ',' + str(result[1]))
saver.save(sess,logdir+"/"+exp_string+'/model'+str(itr))
