Segment Anything 模型结构分析
SAM
首先先来讲一讲SAM。有讲的不对的地方请指出,谢谢!

SAM工作采用新型的prompt工程结合预训练大模型的方法来进行图像分割,并以实现零样本学习(传统做法是基于预训练模型加上微调方法)
整个SAM架构可以分为三大模块:图像编码器模块、提示编码器模块以及遮罩解码器模块。随后会对各个组件进行详细说明
Image Encoder
其image encoder模块基于MAE预训练的ViT架构。原始图像经过等比缩放及边缘填充后统一调整为1024像素分辨率。通过使用kernel大小为16、stride为16的一系列卷积操作将图像划分为64×64×768维的空间特征向量。在这一过程中,将提取到的空间特征向量输入到transformer encoder模块中进行特征提取,在此过程中,输出的空间位置编码进一步压缩至嵌入维度为256
image encoder这一部分所需的计算资源和存储资源消耗显著,在META官方发布的演示视频中所展示的image embedding也是在云端服务器上运行。因此为了实现模型轻量化目标,我们需要对这一部分进行优化。
prompt encoder
在官方发布的META演示案例中,在指定一个点位的基础上实现语义分割过程的具体方法,请参考下图。

也可以框选一个区域,来进行语义分割,如下图所示。

除此之外,在论文中提到了text prompt功能。该功能在演示文中未展示,并认为其作用是为需要分割的目标区域提供一个描述 ,而SAM系统则根据该描述执行相应的分割操作。
在论文中将这三种prompts归类为稀疏类prompt(sparse prompt)。point and box, defined by their top-left and bottom-right points, utilize position embeddings based on transformer principles, which employ sine and cosine functions to encode relative positions and orderings.这些position embeddings还结合了可学习的位置分类嵌入(cls embeddings),从而形成了完整的嵌入表示;(这一部分内容建议参考代码详细了解细节)。
text prompt同样是稀疏类型的prompt, 但是显然无法使用pe来表示它. 在SAM中对应于text的部分使用的是CLIP架构中的文本编码器部分, 具体可以参考CLIP的相关资料.
另一个prompt被定义为mask。该mask通过卷积神经网络执行下采样操作后与image embedding执行逐元素运算(其作用相当于将两者进行数值叠加运算)。这属于一种复杂的数值处理过程
mask decoder
下图是论文中给出的mask decoder的结构

普遍认为大多数人与我有同感,在首次观察时让人感到不知所措。该现象由众多复杂的箭头构成,在该领域内的文献对该现象的描述较为缺乏。因此我们决定从左至右依次进行分析
其中包含图像嵌入(image embedding)与提示嵌入(prompt embedding)两个关键组件。值得注意的是,在输出token(output tokens)之前并没有特别提及这一过程。在图像嵌入模块(image embedding block)中,在多个自注意力层(multiple self-attention layers)之后引入了一个分类器令牌(class token)。而在SAM模型中,则是为了适应语义分割任务的需求,在生成多于一个分割掩码(multiple segmentation masks),如图所示。

针对point prompt而言,在论文研究中常用剪刀这一实例进行说明。将point设置在剪刀柄的位置上后,则分割的目标可能属于这三种情况之一:即全部区域、部分区域或子部分区域。如何确定最终的输出结果则与output tokens的选择密切相关。每个output mask对应一个唯一的output tokens配置,并且还引入了一个IoU prediction head来选择三个mask中最符合预期的那个(该IoU prediction head是一个可学习分支模块,在模型训练过程中基于ground truth进行优化)。
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
# output_tokens由IoU_token和mask_token组成,将他们进行拼接
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
#将原来2维的output_tokens空间化后按照sparse_prompt_embedding的维度进行扩展(sparse_prompt_embedding的维度是什么呢?)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
#将扩展后的output_tokens和sparse_prompt_embeddings进行拼接
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
# 我的理解就是有几个tokens就把image_embedding扩充几次,这个tokens应该和prompt的类型有关
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
# 把image_embeddings和mask_embeddings加上
src = src + dense_prompt_embeddings
# 和ViT一样,加上learnable_position_encoding
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
# token to image attention(先按下不表)
#这里的transformer整合了右下角的token to image attn.后面再说
hs, src = self.transformer(src, pos_src, tokens)
# 从hs中分离出两个token
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
# 梯形的那个上采样
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
那么未作详细说明的内容我们下面就可以展开讲述,在代码实现中该架构不仅仅是一个标准框架的transformer模型;其中包含了一个更为复杂的架构:即在mask decoder模块中左侧被深色背景框出的部分以及右侧下方标记了Token-to-Image Attention的小区域;经过上述处理后随后这个变换体直接输出经过双层卷积处理后的特征图和分化的输出令牌
在论文中指出, mask decoder采用了TwoWayTransformer模型. 在代码实现方面, 我们可以通过segment-anything/segment_anything/build_sam.py路径定位到MaskDecoder类的详细代码. 如图所示.

从图中可以看出,transformer模块将被调用,并基于TwoWayTransformer类进行操作。其中深度参数设置为2时对应于图中的×2缩略形式。

根据该路径可以看出TwoWayTransformer的具体定义。在同一个文件夹中也可以看到TwoWayAttentionBlock的相关内容;此外,在transformer架构中提到了self-attention机制以及图像到文本注意力的位置设置。因此,在这项研究中所提出的TwoWayAttentionBlock完全是自注意力机制对应的代码实现
先来看整体的TwoWayTransformer的代码
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
# 这两个注释之间的内容就是实现框架的结构,depth为2,start
for i in range(depth):
self.layers.append(
# 深色框选部分
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
# 显然这个就是右下角的token to image attn.
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
# 到此能说明结构走向 end
def forward(
# 这部分就是前面两个encoder得到的结果,对于point prompt和box prompt来说其实都是点,无非前者是一个点,后者是两个点
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
# 显然,image embedding是四维的,这里是将第三维开始展开,将后面的维度转化为一维
# 那么具体到image embedding,第三维和第四维也就是high和width
# 然后把三个维从BxCxHxW调整为BxHWxC
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding # 为啥要加两回?
k = keys + image_pe # 合理的,image_embedding直接加上position_encoding
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries) # 这里的norm_final_attn就是一个Layernorm
# 这里还要再看一下
return queries, keys
仔细分析整个架构后
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
至此,在 SAM 模型中基本上完成了主要构建工作;但要全面理解 SAM,则仍需关注其核心要素——数据的设计这一关键问题。在论文研究中对此得到了充分的阐述与讨论;这也是构建 SAM 模型的关键环节之一
