# How does the Segment-Anything Model’s (SAM’s) decoder work?

A deep dive into the Segment-Anything model’s decoding procedure, with a focus on how its self-attention and cross-attention mechanisms work.

The Segment-Anything (SAM) model is a 2D interactive, or guided, segmentation model. SAM requires user prompts to segment an image; the prompts tell the model where to focus. The inputs to the model are a 2D image and a set of user prompts. The output is a set of segmentation masks at different levels, together with a confidence score for each mask.

A segmentation mask is a 2D binary array with the same size as the input image. In this 2D array, an entry at location *(x, y)* has value 1 if the model thinks that the pixel at location *(x, y)* belongs to the segmented area; otherwise, the entry is 0. The confidence scores indicate the model’s belief in the quality of each segmentation: a higher score means higher quality.

The network architecture of SAM consists of an encoder and a decoder:

- The encoder takes in the image and the user prompts to produce an image embedding, an image positional embedding and user prompt embeddings.
- The decoder takes in these embeddings to produce segmentation masks and confidence scores.

This article focuses on how SAM’s decoder works. I will write another article about the encoder.

# SAM’s inputs and outputs

Together with the input image to segment, SAM also requires user prompts, and it supports the following kinds of prompts.

## Different kinds of input user prompts

- Mouse clicks. A mouse click can be a positive click that tells the model to include the clicked location in the produced segmentation mask, or a negative click telling the model to avoid the clicked location. SAM accepts multiple clicks, each either positive or negative.
- Bounding boxes. Bounding boxes are always positive signals, telling the model that the produced segmentation mask should be inside the boxed area. There are no negative bounding boxes. SAM accepts multiple bounding boxes.
- Dense masks. A dense mask is a 2D binary array with the same size as the input image. Entries with value one in this array tell the model which pixels should be included in the predicted mask. The dense mask prompt is always a positive signal. SAM accepts a single dense mask prompt.

Mouse clicks and bounding boxes are called *sparse prompts*.

## Multiple levels of output masks and confidence scores

SAM’s decoder produces three levels of segmentation masks to handle ambiguity in the given prompts.

For example, given a mouse click prompt, the following figure demonstrates the three levels of predicted segmentation masks. In this figure, the green star represents the location of the mouse click. A blue patch represents the predicted segmentation mask.

Every level of segmentation has a predicted confidence score, called the iou score. “iou” means “intersection over union”, the default metric for measuring the performance of segmentation models. Here the notion of intersection over union is used loosely, not mathematically, so please just understand the score as a rough indication of prediction quality from the model’s perspective.

## Three predicted masks or four?

How many levels of segmentation mask does the SAM model produce? The model can be configured, using the *multimask_output* config, to output 3 levels of masks or just one level. However, the underlying model architecture handles these two cases together by always producing four masks as a mask array (or equivalently, a tensor) of four elements, plus 4 confidence scores, also in an array. If SAM is configured to output a single mask, then SAM returns the first mask from the mask array; otherwise it returns the bottom three masks from the mask array. Confidence scores are handled in the same way.

To stay closer to SAM’s code, in this article, I will describe that SAM always returns four masks. This way, the tensor shapes in my description match their shapes in the code.
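The selection logic can be sketched as follows; the helper name select_masks and the stand-in tensors below are mine, for illustration only, not SAM’s actual code:

```python
import torch

def select_masks(all_masks, all_scores, multimask_output):
    """Pick which of the 4 predicted masks and scores to return.

    all_masks:  (4, H, W) tensor, the four predicted mask levels
    all_scores: (4,) tensor, one confidence score per level
    """
    if multimask_output:
        # multi-mask mode: return the bottom three masks
        return all_masks[1:], all_scores[1:]
    # single-mask mode: return only the first mask
    return all_masks[0:1], all_scores[0:1]

all_masks = torch.zeros(4, 256, 256)
all_scores = torch.tensor([0.9, 0.8, 0.7, 0.6])

masks3, scores3 = select_masks(all_masks, all_scores, multimask_output=True)
mask1, score1 = select_masks(all_masks, all_scores, multimask_output=False)
print(masks3.shape, mask1.shape)
```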

# The decoding procedure in code and in charts

Now I will dive into code and network architecture charts to understand how SAM’s decoding procedure works.

Sometimes I need to adapt the code to describe it more easily. For example, if the same variable name is used to receive multiple values at different lines (a bad programming style that you should avoid), I introduce a new name at each line. So please understand that the code snippets I show in this article are not identical to the original SAM implementation, but they are close to it.

The SamPredictor.predict_torch method is a good place to start our understanding of the SAM decoder:

The core of the method consists of three steps:

- Generate embeddings for the user prompts, at lines 222~226.
- Predict low resolution segmentation masks at different levels and a confidence score for each level, at lines 229~235.
- Interpolate the low resolution segmentation masks to the size of the original image, at lines 237~238.

## Important tensors in the decoding procedure

The following flowchart highlights the key tensors during the encoding and decoding process, together with their shapes, because shapes greatly help me understand neural network architectures. **Note I omitted the batch dimensions in this article to shorten shape notations.**

The raw image is a natural image of arbitrary height *H* and width *W*, so its tensor has 3 RGB channels; its shape is 3×H×W.

Before it reaches the network, the arbitrary sized input raw image is resized to a fixed size 1024×1024, while keeping the channel unchanged.

**Image embedding**
The image embedder ImageEncoderViT class produces an embedding for the resized image. The image embedding, self.features in the above snippets, is of shape feature_dim × embed_height × embed_width, which is 256×64×64. The image embedding has a low resolution, only 64×64. Each feature vector at this low resolution has 256 channels. So a 16×16 patch of pixels in the resized natural image contributes to a 256-long feature vector inside this 64×64 embedding space.

**Image positional embedding**
The PromptEncoder class also produces an image positional embedding of shape 256×64×64, the same shape as the image embedding, because the image positional embedding will be element-wise added to the image embedding.

**Dense prompt embedding**
Given user prompts of different kinds, the PromptEncoder produces dense mask embeddings and sparse embeddings. The dense mask embedding has shape 256×64×64, the same shape as the image embedding, because the dense prompt embedding will also be element-wise added to the image embedding.

**Sparse prompt embedding**
The sparse embedding has shape *T*×256, with *T* being the number of tokens and 256 the feature vector length. A token is a tensor representing a sparse prompt or part of a sparse prompt (in the case of bounding boxes). The number of user sparse prompts decides the value of *T*:

- A guidance click contributes a single token.
- A bounding box contributes two tokens, one for the upper-left corner and the other for the bottom-right corner.
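As a quick sketch, the token count *T* can be computed from the number of prompts; the helper below is hypothetical, for illustration only:

```python
def num_sparse_tokens(num_clicks, num_boxes):
    # each guidance click is one token; each bounding box contributes
    # two tokens, one per corner
    return num_clicks + 2 * num_boxes

print(num_sparse_tokens(num_clicks=2, num_boxes=1))  # 4
```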

**How does encoding work?**
This article doesn’t explain how SAM’s encoder produces the above embeddings via the encoding procedure. I will write another article for it.

**The predicted low resolution masks and confidence scores**
Given the four inputs (image embedding, image positional embedding, dense mask embedding and sparse prompt embedding), the MaskDecoder class produces low resolution segmentation masks at 4 levels and 4 confidence scores.

The masks tensor has shape level_of_masks × low_height × low_width, which is 4×256×256. The confidence scores are four float numbers, with shape 4×1.

# How does the MaskDecoder work?

The best place to understand how SAM produces segmentation masks is the MaskDecoder.predict_masks method, shown below.

Note that in the above listing, I renamed the variable src from line 132 on to a new name, *src2*, to make it easier to distinguish between the image embedding (called *src*) before the transformer call at line 132 and the attended image embedding (called *src2*) after the call.

The following two flowcharts make it easier to understand the code.

The first half (lines 126 to 132) prepares arguments for the self.transformer method call and then makes the call at line 132. The transformer call does most of the magic and produces the attended token embeddings and the attended image embedding.

The second half manipulates the attended token embeddings and the attended image embedding to produce low resolution segmentation masks and confidence scores. These manipulations are simple, such as MLP projections, matrix multiplication, and tensor upsampling.

## The input tokens tensor in the first half

The input tokens tensor is a concatenation of three parts, in a fixed order: first the iou_token, then the mask_tokens and then the sparse_prompt_embeddings tokens. The word “input” means this tensor is an input to the transformer call.

**The iou_token (in the blue box)**

The input iou_token tensor comes from PyTorch’s Embedding class. It has the fixed shape 1×256 and contains trainable parameters. You can roughly understand the input iou_token as serving as an input to the transformer call to produce the attended iou_token (in the green box). The attended iou_token will finally produce the confidence scores.

“Roughly” because the above understanding is not fully correct: the input iou_token is not the only information source used to produce the attended iou_token. The transformer uses attention that blends information from all prompt embeddings, the image embedding and the image positional embedding to produce the attended iou_token tensor.

Note that the input iou_token is not connected to any user input, namely the raw input image and the user prompts. You may ask: what are the trainable parameters in iou_token used for? They make sure that the input and output tensors of the attention transformer have the same shape, and their values are adjusted during parameter learning via stochastic gradient descent. You may also ask: do the parameters inside the iou_token have to be trainable? Can we just set them to all zeros? Maybe, but the authors of the SAM model decided to make them trainable, possibly because this results in better model performance. The same argument goes for the mask_tokens, explained in the next section.

**The mask_tokens (in the blue box)**

The input mask_tokens tensor has the fixed shape 4×256 and also comes from PyTorch’s Embedding class; it contains 4 vectors, each 256 entries long. “Roughly”, these 4 vectors serve as inputs to the transformer call to produce the attended mask_tokens (in the green box). The attended mask_tokens will finally produce the segmentation mask prediction heads for the four levels of masks.

**The sparse_prompt_embeddings (in the blue box)**

The sparse_prompt_embeddings input tensor has a varying size; its shape is T×256, with T being the number of sparse prompts. So more sparse prompts, such as more guidance clicks or more bounding boxes, mean a bigger T. Note that the sparse_prompt_embeddings tensor comes directly from the PromptEncoder and is simply concatenated into the tokens tensor.

With these three parts, the input tokens tensor has shape (5+T)×256, with T being the number of sparse prompts, so the tokens tensor has a varying shape.

Like the iou_token, mask_tokens and sparse_prompt_embeddings also contain trainable parameters.
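The concatenation of the three parts can be sketched like this; the Embedding shapes follow the article, but the wiring is a simplified stand-in for the real MaskDecoder code:

```python
import torch
import torch.nn as nn

T = 3            # e.g. three guidance clicks
embed_dim = 256

# trainable tokens, as described in the article
iou_token = nn.Embedding(1, embed_dim)    # shape 1x256
mask_tokens = nn.Embedding(4, embed_dim)  # shape 4x256

# the sparse prompt embeddings come from the PromptEncoder;
# a random tensor stands in for them here
sparse_prompt_embeddings = torch.randn(T, embed_dim)

# fixed order: iou_token, then mask_tokens, then the sparse prompts
tokens = torch.cat(
    [iou_token.weight, mask_tokens.weight, sparse_prompt_embeddings], dim=0
)
print(tokens.shape)  # (5+T) x 256
```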

## The src image embedding with dense prompt information

At lines 126~127, the image_embeddings tensor is element-wise summed with the dense_prompt_embeddings tensor to form the src tensor. Both tensors have shape 256×64×64, so the src tensor also has shape 256×64×64, providing image information blended with dense mask prompt information at the low resolution 64×64 (height×width). Please ignore the call to repeat_interleave at line 126; it doesn’t affect your understanding of the decoding procedure.

The image_embeddings tensor comes from the ImageEncoderViT encoder, and the dense_prompt_embeddings tensor comes from the PromptEncoder.

## The pos_src image positional embedding

The pos_src tensor is the image positional encoding, with shape 256×64×64. It comes from PromptEncoder.

## The transformer call

The transformer call at line 132 takes the above three tensors as inputs and produces two new tensors, the attended tokens embedding tensor, called *hs* in the above listing, and the attended image embedding tensor, called *src2*.

The attention mechanism inside the transformer call blends information from all three of its inputs to produce these two attended output tensors:

- the attended token embedding tensor *hs*
- the attended image embedding tensor *src2*.

**The attended tokens embedding hs**

The attended tokens embedding *hs* has the same structure and shape as the input tokens tensor. Structurally, the *hs* tensor also consists of three attended parts: the iou_token, the mask_tokens and the sparse_prompt_embeddings. Its shape is also (5+T)×256.

From the *hs* tensor, only the *iou_token* part and the *mask_tokens* part are used for mask generation; the *sparse_prompt_embeddings* part is ignored:

- The *iou_token* part is used to produce the confidence scores.
- The *mask_tokens* part is used to produce the segmentation heads.

Why is there an output sparse_prompt_embeddings part if it is ignored? Because the transformer performs an attention operation, which is designed to have the same input and output shape.

**The output attended image embedding tensor src2**

The *src2* tensor has shape 256×64×64, same as the input image embedding tensor.

## Producing final segmentation masks and their confidence scores

**Producing segmentation masks in the second half flowchart**

To produce the four segmentation masks, SAM follows the usual approach of matrix-multiplying a segmentation head with the attended image embedding, like most segmentation models.

The segmentation head comes from the attended mask_tokens tensor, whose shape is 4×256. This tensor has 4 vectors, each consisting of 256 entries. Each vector is a segmentation head for one level of mask.

This attended mask_tokens tensor is projected into the *hyper_in* tensor with shape 4×32 by the *output_hypernetworks_mlps* network.

On the other branch of the second half flowchart, the attended image embedding tensor *src2* of shape 256×64×64 is upscaled to 32×256×256, that is, a shorter feature dimension (256 to 32) but a larger resolution (64×64 to 256×256), and then reshaped to 32×65536. This 32×65536 tensor will be multiplied with the segmentation head tensor to produce the segmentation masks.

Now we can matrix-multiply the segmentation head tensor of shape 4×32 with the image embedding tensor of shape 32×65536 to produce the segmentation masks of shape 4×65536, which is reshaped back to 4×256×256. These are the final segmentation masks: 4 masks at different levels, each with the low resolution of 256×256.
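These second-half shape manipulations can be checked with stand-in tensors. A single nn.Linear stands in for the *output_hypernetworks_mlps* MLPs and a random tensor stands in for the *output_upscaling* result, so only the shapes are meaningful here:

```python
import torch
import torch.nn as nn

attended_mask_tokens = torch.randn(4, 256)  # from the transformer call
hyper_in = nn.Linear(256, 32)(attended_mask_tokens)  # segmentation heads, 4x32

# stand-in for the upscaled attended image embedding (256x64x64 -> 32x256x256)
upscaled = torch.randn(32, 256, 256)
flat = upscaled.reshape(32, 256 * 256)  # 32x65536

# heads (4x32) @ image features (32x65536) -> masks (4x65536) -> 4x256x256
masks = (hyper_in @ flat).reshape(4, 256, 256)
print(masks.shape)
```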

As mentioned earlier in this article, these masks of low resolution 256×256 are interpolated to the original input image dimensions H×W before being returned to the user.
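That interpolation step can be sketched with PyTorch’s F.interpolate; the H and W values below are arbitrary examples, and the bilinear mode is an assumption of this sketch:

```python
import torch
import torch.nn.functional as F

low_res_masks = torch.randn(1, 4, 256, 256)  # interpolate expects a batch dim
H, W = 768, 1024                             # example original image size

upscaled_masks = F.interpolate(
    low_res_masks, size=(H, W), mode="bilinear", align_corners=False
)
print(upscaled_masks.shape)  # 1 x 4 x H x W
```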

**Producing confidence scores in the second half flowchart**

To produce confidence scores for the four levels of segmentation masks, an MLP network *iou_prediction_head* projects the attended *iou_token* of shape 1×256 to a new tensor of shape 1×4. Each entry of this 1×4 tensor is a float number, interpreted as the confidence score for the segmentation mask at the corresponding level.

## Trainable parameters

In the second half of the flowcharts, trainable parameters exist in the components that manipulate tensors: *iou_prediction_head*, *output_hypernetwork_mlps* and *output_upscaling*.

# The attention mechanism

Now it is time to understand how the attention mechanism inside the transformer call works. The call site is in MaskDecoder.predict_masks at line 132:

It invokes the TwoWayTransformer.forward method, shown below. I renamed some variables, for example, *queries* to *queries2* and *queries3* to ease the discussion.

The purpose of this method is to use attention to blend information from its input arguments:

- the image embedding *image_embedding*, with the image positional encoding already added in (see the flowchart for the first half of MaskDecoder.predict_masks),
- the image positional embedding *image_pe*,
- the sparse prompt embedding *point_embedding*.

The call returns the attended sparse prompt embedding as *point_embed_attn4* and the attended image embedding as *image_embed_attn2* at line 150.

From the call site point of view:

The returned *hs* tensor receives the attended sparse prompt embedding *point_embed_attn4*. The *src2* tensor receives the attended image embedding *image_embed_attn2*. As explained before, *hs* and *src2* are used to produce the final segmentation masks and confidence scores.

## TwoWayTransformer.forward

Let’s now go over the code in my revised version of the TwoWayTransformer.forward method with the help of its flowchart.

At lines 129~130, the method first reshapes the image embedding tensor *image_embedding* from shape 256×64×64 to 4096×256. In my version of the code, I called the result *image_embedding1*; at line 134, the *keys0* variable is assigned *image_embedding1*. The method does the same reshaping for the image positional embedding tensor *image_pe*, resulting in the *image_pe1* tensor, with shape 4096×256.

At lines 135~137, the *point_embedding* tensor and the *image_embedding1* tensor go through a TwoWayAttentionBlock component, *self.layers[0]*, to produce attended versions of them with unchanged shapes:

- *point_embed_attn1* of shape (5+T)×256
- *image_embed_attn1* of shape 4096×256.

Then at lines 139~141, *point_embed_attn1* and *image_embed_attn1* go through the same TwoWayAttentionBlock to produce yet another pair of attended versions:

- *point_embed_attn2* of shape (5+T)×256
- *image_embed_attn2* of shape 4096×256.

The *image_embed_attn2* tensor is returned to the caller unchanged, but the *point_embed_attn2* tensor is further attended at line 144 to produce the final *point_embed_attn4* tensor before returning. We can ignore these further manipulations, because line 144 uses the same TwoWayAttentionBlock mechanism, which I will explain now.

Trainable parameters live in the layernorm at line 148. Other trainable parameters live in the TwoWayAttentionBlock.

# TwoWayAttentionBlock.forward

Now it is time to finally dive into the TwoWayAttentionBlock.forward method, from the point of view of *self.layers[0]*. Below I show my version of the code with more meaningful variable names. This method accepts the following inputs:

- the sparse prompt embedding *point_embedding*, shaped (5+T)×256
- the image embedding *image_embedding*, shaped 4096×256
- the image positional embedding *image_pe1*, shaped 4096×256.

It returns the following two new tensors:

- the attended sparse prompt embedding *point_embed7_attn*, shaped (5+T)×256
- the attended image embedding *image_embed2_attn*, shaped 4096×256.

The method applies three attention operations on its inputs:

- self attention from the sparse prompt embedding to itself, that is, from *point_embedding* to *point_embedding*, at line 232, to produce the *point_embed1_attn* tensor of shape (5+T)×256.
- cross attention from the sparse prompt embedding (after the residual addition at line 240) to the image embedding (after being summed with the image positional embedding at line 241), at line 239, to produce the *point_embed3_attn* tensor of shape (5+T)×256.
- cross attention from the image embedding (after being summed with the image positional embedding at line 253) to the sparse prompt embedding (after the residual addition at line 254), at line 252, to produce the *image_embed1_attn* tensor of shape 4096×256.

The two cross attentions from points 2 and 3 give the class its name: TwoWayAttentionBlock.

In the above listing:

- all variables with the “*point_embed*” prefix, such as *point_embedding1* and *point_embedding3_attn*, have shape (5+T)×256.
- all variables with the “*image_embed*” prefix, such as *image_embedding1_attn*, have shape 4096×256.
- the image positional embedding *image_pe1* has shape 4096×256.

**Trainable parameters**
Trainable parameters live in the layernorm operations, such as *self.norm1*, and the MLP operations, such as *self.mlp*.

# The Attention.forward mechanism

The three attention operations from above use the same attention mechanism, implemented in the Attention.forward method. That’s the final piece of code we need to dive into.

Because of the three attentions above, we need to dive into Attention.forward three times to understand this attention mechanism. Accordingly, I will provide three versions of the code, each with more meaningful variable names specific to its attention kind, namely the prompt token self attention kind, the token-to-image kind and the image-to-token kind.

## Attention.forward for the sparse token self attention kind

Since this code version is for self attention from *point_embedding* to *point_embedding*, it only needs one passed-in parameter, *point_embedding*, with shape (5+T)×256.

Line 315 uses a linear layer to project the *point_embedding* tensor to the *q* tensor of the same shape (5+T)×256. Note that here T is the number of sparse prompt tokens, such as guidance clicks or corners of bounding boxes. Since a user can provide more than one click or bounding box, T is a varying number.

Lines 316~317 perform the same kind of linear projections.

**How does this linear layer handle input of varying size T?**
Here is how it works: the (5+T) part serves as the batch dimension to the linear layer. The linear layer itself has a fixed input size of 256. In other words, for each token in this (5+T) batch dimension, the same linear layer, hence the same set of trainable parameters, performs the projection.
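This batching behaviour is easy to demonstrate with a standalone nn.Linear (not SAM’s actual *q_proj* layer):

```python
import torch
import torch.nn as nn

q_proj = nn.Linear(256, 256)  # fixed input and output size of 256

# the same layer, hence the same parameters, handles any token count:
# the (5+T) dimension simply acts as a batch dimension
shapes = []
for T in (1, 3, 10):
    point_embedding = torch.randn(5 + T, 256)
    q = q_proj(point_embedding)
    shapes.append(tuple(q.shape))
print(shapes)  # [(6, 256), (8, 256), (15, 256)]
```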

The rest of the code only uses linear algebra operations such as matrix multiplication and softmax. These operations don’t involve any trainable parameters, so a varying T doesn’t matter, just like PyTorch’s matrix multiplication method *matmul* supports matrices of arbitrary sizes.

The linear projection *out_proj* at line 334 uses the batch dimension to handle the varying T, the same way the *q_proj* projection does.

**Multi-head splitting**
Line 319 performs multi-head splitting. It reshapes a (5+T)×256 tensor into an 8×(5+T)×32 tensor, with the first dimension 8 being the number of heads. The reason for the multi-head splitting is to achieve better parallel computation, since the 8 heads can be processed in parallel. The downside is that only information from the same head is blended using attention.

The following flowchart describes the main matrix operations. Note that the matrix grids in the flowchart don’t match their actual shapes; they are for demonstration purposes only.

**The q2×k3 multiplication**

The *q2* tensor is the *point_embedding*, and the *k3* tensor is the transposed *point_embedding*. The shape of *q2* is (5+T)×32, because the drawing is only for one head out of the eight heads.

The resulting *attn* tensor is interpreted as the pairwise token similarity matrix. This is because an entry in the *attn* matrix is the dot product of a row from *q2* and a column from *k3*. A dot product measures similarity between two vectors. Here each vector is the embedding of a token (guidance click, bounding box corner, iou token, mask token), so the dot product between two token embeddings tells us how similar the two tokens are. In the flowchart, I used the red vectors to demonstrate this dot product:

Note that the dot product is also an un-normalised cosine similarity. You may ask why we don’t compute the cosine similarity here. We don’t have to: various layernorm operations applied to these matrices later serve that purpose.

**The attn×v2 multiplication**

This multiplication multiplies the pairwise token similarity matrix with the point embedding matrix, and we call the result the self-attended point embedding. To understand why we call it that, let’s again have a look at how an entry of the resulting matrix is computed:

An entry in the resulting out1 matrix is a weighted sum of the information from all tokens at one feature dimension (the feature dimension in a single head is 32). The weights in this weighted sum come from a row in the attn matrix; they describe the similarity between every pair of tokens. This weighted sum interpretation is exactly the intuition behind the attention mechanism.
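The whole single-head computation can be sketched with random stand-in tensors. Note the sketch scales the dot products by the square root of d before the softmax, following the standard scaled dot-product attention form:

```python
import torch

T, d = 2, 32                  # d is one head's feature dimension
q2 = torch.randn(5 + T, d)    # projected point embedding
k3 = torch.randn(5 + T, d).T  # transposed projected point embedding, d x (5+T)
v2 = torch.randn(5 + T, d)

# pairwise token similarities, normalised into weights by softmax
attn = torch.softmax((q2 @ k3) / d ** 0.5, dim=-1)
out1 = attn @ v2  # each row is a weighted sum of information from all tokens
print(attn.shape, out1.shape)  # 7x7 and 7x32
```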

**Trainable parameters**
Trainable parameters live in the input projection operations *q_proj*, *k_proj*, *v_proj* and the output projection operation *out_proj*.

## Attention.forward for cross attention from token embedding to image embedding kind

The following snippet is for the cross attention from the token embedding to the image embedding, which produces an attended token embedding. Since it is a cross attention, the code has several passed-in arguments.

Note that the linear projections at lines 316~318 reduce the channel size from 256 to 128. This is different from the token self attention case, where the linear projections don’t change the channel size.

After the *q2×k3* matrix multiplication, an entry in the *attn* matrix is the dot product between the embedding of a point and the embedding of a pixel. So *attn* is the pairwise similarity matrix between points and pixels. Note that the shape of attn is (5+T)×4096; it is not a square matrix, because there are far fewer points, (5+T) of them, than pixels, 4096 of them.

Then the *attn×v2* multiplication uses weights from the pairwise similarity matrix to sum up rows of the *v2* matrix, or equivalently, the *image_embedding* matrix. The result of this matrix multiplication is the token embedding attended to the image embedding, that is, the *out1* matrix.
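The shapes can be verified with a standalone sketch; the multi-head splitting is omitted here, and inline Linear layers stand in for the real q_proj, k_proj and v_proj:

```python
import torch
import torch.nn as nn

T, d = 2, 128                              # channels reduced 256 -> 128
point_embedding = torch.randn(5 + T, 256)  # token embeddings
image_embedding = torch.randn(4096, 256)   # 64x64 image features, flattened

q2 = nn.Linear(256, d)(point_embedding)    # (5+T) x 128
k2 = nn.Linear(256, d)(image_embedding)    # 4096 x 128
v2 = nn.Linear(256, d)(image_embedding)    # 4096 x 128

# point-to-pixel similarities: (5+T) x 4096, not a square matrix
attn = torch.softmax((q2 @ k2.T) / d ** 0.5, dim=-1)
out1 = attn @ v2                           # attended token embedding, (5+T) x 128
print(attn.shape, out1.shape)
```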

## Attention.forward for cross attention from image embedding to token embedding kind

Note again, lines 322 to 324 reduce the channel dimension from 256 to 128.

The mechanism is the same as the other cross attention described before, so I won’t repeat it here.

# Conclusion

This article describes how SAM’s decoder works using code snippets and flowcharts. In a future article, I will cover how SAM’s encoder works.