Advertisement

多模态text-image模型之ITM loss(blip)

阅读量:

主要代码:

复制代码
 # forward the positve image-text pair

    
 # 正向传播正面的图像文本对
    
 output_pos = self.text_encoder.bert(encoder_embeds=text_embeds, 
    
                                 attention_mask=text.attention_mask,
    
                                 encoder_hidden_states=image_embeds,
    
                                 encoder_attention_mask=image_atts,      
    
                                 return_dict=True,
    
                                 mode='fusion',
    
                                )            
    
 with torch.no_grad():
    
     bs = image.size(0)  # 获取批量大小          
    
     weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)  # 对image到text的相似度进行softmax,沿着第二个维度计算
    
     weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)  # 对text到image的相似度进行softmax,沿着第二个维度计算
    
  
    
     weights_i2t.fill_diagonal_(0)  # 将权重矩阵的对角线设为0
    
     weights_t2i.fill_diagonal_(0)  # 将权重矩阵的对角线设为0
    
  
    
 # select a negative image for each text
    
 # 为每个文本选择一个负面的图像
    
 image_embeds_neg = []    
    
 for b in range(bs):
    
     neg_idx = torch.multinomial(weights_t2i[b], 1).item()  # 根据权重选择负面图像的索引
    
     image_embeds_neg.append(image_embeds[neg_idx])  # 添加负面图像到列表
    
 image_embeds_neg = torch.stack(image_embeds_neg, dim=0)  # 将负面图像张量堆叠起来
    
  
    
 # select a negative text for each image
    
 # 为每张图像选择一个负面的文本
    
 text_embeds_neg = []
    
 text_atts_neg = []
    
 for b in range(bs):
    
     neg_idx = torch.multinomial(weights_i2t[b], 1).item()  # 根据权重选择负面文本的索引
    
     text_embeds_neg.append(text_embeds[neg_idx])  # 添加负面文本到列表
    
     text_atts_neg.append(text.attention_mask[neg_idx])  # 添加负面文本的注意力掩码到列表
    
 text_embeds_neg = torch.stack(text_embeds_neg, dim=0)  # 将负面文本张量堆叠起来
    
 text_atts_neg = torch.stack(text_atts_neg, dim=0)  # 将负面文本的注意力掩码张量堆叠起来
    
  
    
 text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)  # 拼接所有的文本张量
    
 text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)  # 拼接所有的文本的注意力掩码张量
    
  
    
 image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)  # 拼接所有的图像张量
    
 image_atts_all = torch.cat([image_atts, image_atts], dim=0)  # 拼接所有的图像的注意力掩码张量
    
  
    
 output_neg = self.text_encoder.bert(encoder_embeds=text_embeds_all, 
    
                                 attention_mask=text_atts_all,
    
                                 encoder_hidden_states=image_embeds_all,
    
                                 encoder_attention_mask=image_atts_all,      
    
                                 return_dict=True,
    
                                 mode='fusion',
    
                                )                         
    
  
    
 vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0)  # 拼接正负样本的嵌入表示
    
 vl_output = self.itm_head(vl_embeddings)  # 输入到信息论训练头部            
    
  
    
 itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],  # 创建信息论训练标签
    
                    dim=0).to(image.device)  # 将标签转移到相同的设备上
    
 loss_itm = F.cross_entropy(vl_output, itm_labels)  # 计算信息论训练损失     
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-18/ZJnRm4sUOuKyEqDxle25WtI3LVaN.png)

参考:[多模态text-image模型之ITM loss-博客]( "多模态text-image模型之ITM loss-博客")

求Loss的代码:

复制代码
    loss_itm = F.cross_entropy(vl_output, itm_labels)

vl_output 是模型输出的分类得分,itm_labels 是每个样本的真实标签。

vl_output:模型输出的是经过训练头部(self.itm_head)的得分,这个头部是一个全连接层,用于将模型学到的特征映射到正面和负面类别的得分。

itm_labels:模型对应的标签,包含了每个样本的真实标签。torch.ones(bs, dtype=torch.long) 是正面样本的标签,设为 1,torch.zeros(2 * bs, dtype=torch.long) 是负面样本的标签,设为 0。然后,使用 torch.cat 函数将这些标签连接起来,形成一个完整的标签张量。

loss_itm:通过调用 F.cross_entropy 函数计算模型输出和真实标签之间的交叉熵损失。这个损失反映了模型预测和实际标签之间的差异,用于指导模型参数的更新,以便更好地区分正面和负面样本。

全部评论 (0)

还没有任何评论哟~