多模态text-image模型之ITC loss
发布时间
阅读量:
阅读量
最近在看多模态内容,记录一下文图模型中常用的损失函数。最先提出ITC loss的是论文ALBEF,下面是文章对该Loss的定义
假设有输入图片 I 经过image encoder之后变成{v_{cls}, v_1, …, v_N},输入文本 T 经过 text encoder 后变成{w_{cls}, w_1,…, w_N}
ITC Loss 的全称是 Image-Text Contrastive Loss ,为了在融合之前学习更好的unimodal表示,它学习s = g_v (v_{cls})^T g_w(w_{cls}),这里的 g_v和g_w函数是给cls token embedding降维的线性层。另一方面,文图对会进入一个momentum unimodal encoders(这个结构的作用是通过结合过去更新中积累的知识,帮助稳定和提高学习表示的质量),变成g′_v (v′_{cls}) 和g′_w(w′_{cls})
定义:
s(I, T) = g_v (v_{cls})^T g′_w(w′_{cls}) \\ s(T, I) = g_w(w_{cls})^Tg′_v (v′_{cls})
对于每个图像和文本,我们计算softmax归一化的图像到文本和文本到图像相似度为:

其中的\tau是可学习的参数。令onehot相似度的真实值是y^{i2t} (I) 和y^{t2i}(T),真值中负样本对的概率为0,正样本对的概率为1。
ITC loss为p和y的交叉熵:

ALBEF代码中ITC loss对应的主要代码:
sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp
# image_feat和text_feat分别是图片和文本特征,text_feat_all和image_feat_all是从momentum encoder中取出来的文本、图像特征
# self.temp = nn.Parameter(torch.ones([]) * config['temp']) ,引入一个可学习的参数,可以对计算的结果进行缩放,从而调整模型
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
# F.log_softmax(sim_i2t, dim=1)对sim_i2t的每一行进行log_softmax计算
# sim_i2t和sim_i2t_targets的形状一样,sim_i2t_targets是真实值
# F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets是矩阵按元素相乘
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
# loss_t2i中的操作同上
loss_ita = (loss_i2t+loss_t2i)/2 #求平均得到ITC Loss
之后在更新同一篇文章中的Image-Text Matching (ITM) loss。
全部评论 (0)
还没有任何评论哟~
