
Fine-Tuning SAM 2 on a Custom Dataset: Tutorial

Jennifer Aniston
2025-03-04

Meta's Segment Anything Model 2 (SAM 2) is the latest innovation in segmentation technology. It is Meta’s first unified model that can segment objects in both images and videos in real time.

But why fine-tune SAM 2 if it can already segment anything?

While SAM 2 is powerful out-of-the-box, its performance on rare or domain-specific tasks may not always meet expectations. Fine-tuning allows you to adapt SAM2 to your specific needs, improving its accuracy and efficiency for your particular use case.

In this article, I’ll guide you step-by-step through the fine-tuning process of SAM 2.

What Is SAM 2?

SAM 2 is a foundation model developed by Meta for promptable visual segmentation in images and videos. Unlike its predecessor, SAM, which primarily focused on static images, SAM 2 is designed to handle the complexities of video segmentation as well.


SAM2 - Task, Model, and Data (Source: Ravi et al., 2024)

It employs a transformer architecture with streaming memory, enabling real-time video processing. SAM 2 was trained on a vast and varied collection of data featuring the novel SA-V dataset, which includes more than 600,000 masklet annotations spanning 51,000 videos.

Its data engine, which allows for interactive data collection and model improvement, is what enables SAM 2 to continuously learn and adapt, making it more effective at handling new and challenging data. However, for domain-specific tasks or rare objects, fine-tuning is essential to achieve optimal performance.

Why Fine-Tune SAM 2?

In the context of SAM 2, fine-tuning is the process of further training the pre-trained SAM 2 model on a specific dataset to enhance its performance for a particular task or domain. While SAM 2 is a powerful tool trained on a broad and diverse dataset, its general-purpose nature may not always yield optimal results for specialized or rare tasks.

For example, if you're working on a medical imaging project that requires the identification of specific tumor types, the model's performance might fall short due to its generalized training.


The fine-tuning process

Fine-tuning SAM 2 addresses this limitation by allowing you to adapt the model to your specific dataset. This process improves the model's accuracy and makes it more effective for your unique use case.

Here are the key benefits of fine-tuning SAM 2:

  1. Improved accuracy: By fine-tuning the model on your specific dataset, you can significantly enhance its accuracy, ensuring better performance in your targeted application.
  2. Specialized segmentation: Fine-tuning enables the model to become adept at segmenting specific object types, visual styles, or environments that are relevant to your project, providing tailored results that a general-purpose model may not achieve.
  3. Efficiency: Fine-tuning is often more efficient than training a model from scratch. It typically requires less data and time, making it a practical solution for quickly adapting the model to new or niche tasks.

Getting Started With Fine-Tuning SAM 2: Prerequisites

To get started with fine-tuning SAM 2, you’ll need to have the following prerequisites in place:

  1. Access to the SAM 2 model and codebase: You can download the pre-trained SAM 2 model and its code from Meta's GitHub repository.
  2. A suitable dataset: You'll need a dataset that includes ground truth segmentation masks. For this tutorial, we’ll be using the Chest CT Segmentation dataset, which you can download and prepare for training.
  3. Computational resources: Fine-tuning SAM 2 requires hardware with sufficient computational power. GPUs are highly recommended to ensure the process is efficient and manageable, especially when working with large datasets or complex models. In this example, an A100 GPU on Google Colab is used.

Software and other requirements:

  • Python 3.11 or higher
  • PyTorch
  • OpenCV: Install it using !pip install opencv-python

Preparing the Dataset for Fine-Tuning SAM 2

The quality of your dataset is crucial for fine-tuning the SAM 2 model. High-quality annotated images or videos with accurate segmentation masks are essential to achieving optimal performance. Precise annotations enable the model to learn the correct features, leading to better segmentation accuracy and robustness in real-world applications.

1. Data acquisition

The first step involves acquiring the dataset, which forms the backbone of the training process. We sourced our data from Kaggle, a reliable platform that provides a diverse range of datasets. Using the Kaggle API, we downloaded the data in the required format, ensuring that the images and corresponding segmentation masks were readily available for further processing.

2. Data extraction and cleaning

After downloading the datasets, we performed the following steps:

  • Unzipping and cleaning: Extract the data from the downloaded zip files and delete unnecessary files to save disk space.
  • ID extraction: Unique identifiers (IDs) for images and masks are extracted to ensure correct mapping between them during training.
  • Removing unnecessary files: Remove any noisy or irrelevant files, such as certain images with known issues, to maintain the integrity of the dataset.

3. Conversion to usable formats

Since the SAM 2 model requires input in specific formats, we converted the data as follows (a code sketch follows the list):

  • DICOM to NumPy: The DICOM images were read and stored as NumPy arrays, which were then resized to a standard dimension of 512x512 pixels.
  • NRRD to NumPy for masks: Similarly, NRRD files containing masks for lungs, heart, and trachea were processed and saved as NumPy arrays. These masks were then reshaped to match the corresponding images.
  • Conversion to JPG/PNG: For better visualization and compatibility, the NumPy arrays were converted to JPG/PNG formats. This step included normalizing the image intensity values and ensuring the masks were correctly oriented.
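
To illustrate this conversion pipeline, here is a minimal sketch, assuming pydicom and pynrrd are installed (!pip install pydicom pynrrd); the file paths and the slice layout of the NRRD volume are assumptions for this dataset, not a verified implementation.

import cv2
import numpy as np
import pydicom
import nrrd

def dicom_to_png(dicom_path, out_path):
    # Read the DICOM slice and resize it to the standard 512x512 dimension.
    img = pydicom.dcmread(dicom_path).pixel_array.astype(np.float32)
    img = cv2.resize(img, (512, 512))
    # Normalize intensities to 0-255 before saving as PNG.
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    cv2.imwrite(out_path, img)

def nrrd_to_mask_slices(nrrd_path):
    # Read an NRRD mask volume (e.g., lungs) and resize each slice to 512x512.
    data, _ = nrrd.read(nrrd_path)
    return [cv2.resize(data[:, :, i].astype(np.uint8), (512, 512),
                       interpolation=cv2.INTER_NEAREST)  # nearest keeps labels intact
            for i in range(data.shape[2])]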

4. Saving and organizing data

The processed images and masks are then organized into respective folders for easy access during the fine-tuning process. Additionally, paths to these images and masks are written into a CSV file (train.csv) to facilitate data loading during training.
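
A minimal sketch of that last step, assuming images and masks share filenames and the column names below (both are assumptions about the folder layout):

import os
import pandas as pd

# Pair each image with its mask by filename and record the pairs in train.csv.
records = [{"ImageId": fname, "MaskId": fname}
           for fname in sorted(os.listdir("images"))]
pd.DataFrame(records).to_csv("train.csv", index=False)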

5. Visualization and validation

The final step involved validating the dataset to ensure its accuracy:

  • Visualization: We visualized the image-mask pairs by overlaying the masks on the images. This helped us check the alignment and accuracy of the masks.
  • Inspection: By inspecting a few samples, we could confirm that the dataset was correctly prepared and ready for use in fine-tuning.

Here is a quick notebook that walks you through the code for dataset creation. You can either follow this data-creation path or directly use any dataset available online in the same format as the one mentioned in the prerequisites.

Fine-Tuning SAM 2

Segment Anything Model 2 contains several components, but the key to faster fine-tuning is to train only the lightweight components, such as the mask decoder and prompt encoder, rather than the entire model. The steps for fine-tuning this model are as follows:

Step 1: Install SAM-2

To start the fine-tuning process, we need to install the SAM-2 library, which provides the code for Segment Anything Model 2 and is designed to handle various segmentation tasks effectively. The installation involves cloning the SAM-2 repository from GitHub and installing the necessary dependencies.

!git clone https://github.com/facebookresearch/segment-anything-2
%cd /content/segment-anything-2
!pip install -q -e .

This code snippet ensures the SAM2 library is correctly installed and ready for use in our fine-tuning workflow.

Step 2: Download the dataset

Once the SAM-2 library is installed, the next step is to acquire the dataset we’ll be using for fine-tuning. We use a dataset available on Kaggle, specifically a chest CT segmentation dataset containing images and masks of lungs, heart, and trachea.

The dataset contains:

  • images.zip: Images in RGB format
  • masks.zip: Segmentation masks in RGB format
  • train.csv: CSV file with image names


Image from the CT Scan Dataset

In this blog, we’ll use only images and masks of lungs for segmentation. The Kaggle API allows us to download datasets directly to our environment. We start by uploading the kaggle.json file from Kaggle to access any dataset easily.

To get kaggle.json, go to the Settings tab under your user profile and select Create New Token. This triggers the download of kaggle.json, a file containing your API credentials.

# get dataset from Kaggle
from google.colab import files
files.upload()  # This will prompt you to upload the kaggle.json file

!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d polomarco/chest-ct-segmentation

Unzip the data:

!unzip chest-ct-segmentation.zip -d chest-ct-segmentation

With the dataset ready, let’s start the fine-tuning process. As I previously mentioned, the key here is to fine-tune only the lightweight components of SAM2, such as the mask decoder and prompt encoder, rather than the entire model. This approach is more efficient and requires fewer resources.

Step 3: Download SAM-2 checkpoints

For the fine-tuning process, we need to start with pre-trained SAM2 model weights. These weights, called "checkpoints," are the starting point for further training. The checkpoints have been trained on a wide range of images, and by fine-tuning them on our specific dataset, we can achieve better performance on our target tasks.

!wget -O sam2_hiera_tiny.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
!wget -O sam2_hiera_small.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
!wget -O sam2_hiera_base_plus.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
!wget -O sam2_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"

In this step, we download various SAM-2 checkpoints that correspond to different model sizes (e.g., tiny, small, base_plus, large). The choice of checkpoint can be adjusted based on the computational resources available and the specific task at hand.

Step 4: Data preparation

With the dataset downloaded, the next step is to prepare it for training. This involves splitting the dataset into training and testing sets and creating data structures that can be fed into the SAM 2 model during fine-tuning.


We split the dataset into a training set (80%) and a testing set (20%) to ensure that we can evaluate the model's performance after training. The training data will be used to fine-tune the SAM 2 model, while the testing data will be used for inference and evaluation.
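
Here is a minimal sketch of that split. The column names (ImageId, MaskId) are an assumption about the CSV layout; adjust them to match your train.csv.

import pandas as pd

df = pd.read_csv("chest-ct-segmentation/train.csv")
train_df = df.sample(frac=0.8, random_state=42)  # 80% for fine-tuning
test_df = df.drop(train_df.index)                # 20% held out for evaluation

# Build (image path, mask path) entries for the data-loading code below.
train_data = [{"image": f"chest-ct-segmentation/images/{row['ImageId']}",
               "annotation": f"chest-ct-segmentation/masks/{row['MaskId']}"}
              for _, row in train_df.iterrows()]
test_data = [{"image": f"chest-ct-segmentation/images/{row['ImageId']}",
              "annotation": f"chest-ct-segmentation/masks/{row['MaskId']}"}
             for _, row in test_df.iterrows()]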

After splitting your dataset into training and testing sets, the next step involves creating binary masks, selecting key points within these masks, and visualizing these elements to ensure the data is correctly processed. 

1. Reading and resizing images: The process starts by randomly selecting an image and its corresponding mask from the dataset. The image is converted from BGR to RGB format, which is the standard color format for most deep learning models. The corresponding annotation (mask) is read in grayscale mode. Then, both the image and the annotation mask are resized to a maximum dimension of 1024 pixels, maintaining the aspect ratio to ensure that the data fits within the model's input requirements and reduces computational load.

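A hedged sketch of this step; the helper name and the entry format (from the split sketch above) are ours:

import cv2
import numpy as np

def read_image(entry, max_dim=1024):
    # Read the image (BGR -> RGB) and its grayscale annotation mask.
    img = cv2.cvtColor(cv2.imread(entry["image"]), cv2.COLOR_BGR2RGB)
    ann = cv2.imread(entry["annotation"], cv2.IMREAD_GRAYSCALE)
    # Resize so the longest side is at most max_dim, keeping the aspect ratio.
    r = min(max_dim / img.shape[1], max_dim / img.shape[0])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    ann = cv2.resize(ann, (int(ann.shape[1] * r), int(ann.shape[0] * r)),
                     interpolation=cv2.INTER_NEAREST)  # nearest keeps labels intact
    return img, ann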

2. Binarization of segmentation masks: The multi-class annotation mask (which might have multiple object classes labeled with different pixel values) is converted into a binary mask. This mask highlights all the regions of interest in the image, simplifying the segmentation task to a binary classification problem: object vs. background. The binary mask is then eroded using a 5x5 kernel.

Erosion slightly reduces the mask's size, which helps avoid boundary effects when selecting points. This ensures the selected points are well within the object's interior rather than near its edges, which might be noisy or ambiguous.

Key points are selected from within the eroded mask. These points act as prompts during the fine-tuning process, guiding the model on where to focus its attention. The points are selected randomly from the interior of the objects to ensure they are representative and not influenced by noisy boundaries.

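Below is a sketch of the binarization, erosion, and point sampling just described. The 5x5 kernel follows the text; the helper name and the number of points are ours.

import cv2
import numpy as np

def get_points(ann, num_points=1):
    # Binarize the multi-class mask: object vs. background.
    binary_mask = (ann > 0).astype(np.uint8)
    # Erode with a 5x5 kernel so sampled points sit well inside the object.
    eroded = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded > 0)
    if len(coords) == 0:
        coords = np.argwhere(binary_mask > 0)  # fall back if erosion removed everything
    points = []
    for _ in range(num_points):
        y, x = coords[np.random.randint(len(coords))]
        points.append([[x, y]])  # prompts are (x, y) pairs
    return binary_mask, np.array(points, dtype=np.float32)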

3. Visualization: This step is crucial for verifying that the data preprocessing steps have been executed correctly. By visually inspecting the points on the binarized mask, you can ensure that the model will receive appropriate input during training. Finally, the binary mask is reshaped and formatted correctly (with dimensions suitable for the model input), and the points are also reshaped for further use in the training process. The function returns the processed image, binary mask, selected points, and the number of masks found.

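A minimal visualization sketch (matplotlib assumed) that reproduces the three-panel figure below, using the helpers defined above:

import matplotlib.pyplot as plt

img, ann = read_image(train_data[0])
binary_mask, points = get_points(ann, num_points=5)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img)
axes[0].set_title("Original image")
axes[1].imshow(binary_mask, cmap="gray")
axes[1].set_title("Binarized mask")
axes[2].imshow(binary_mask, cmap="gray")
axes[2].scatter(points[:, 0, 0], points[:, 0, 1], c="red", s=20)
axes[2].set_title("Binarized mask with points")
for ax in axes:
    ax.axis("off")
plt.show()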

This visualization produces a figure like the one below, containing the original image from the dataset along with its binarized mask and the binarized mask with points.


Original image, binarized mask, and binarized mask with points for the dataset.

Step 5: Fine-tuning the SAM 2 model

Fine-tuning the SAM2 model involves several steps, including loading the model, setting up the optimizer and scheduler, and iteratively updating the model weights based on the training data.

Load the model checkpoints:

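A sketch of building the model and predictor, assuming the small checkpoint downloaded in Step 3; the config name follows the repository layout at the time of writing and may differ in later versions.

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

sam2_checkpoint = "sam2_hiera_small.pt"  # downloaded in Step 3
model_cfg = "sam2_hiera_s.yaml"          # matching config from the repo

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)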

We start by building the SAM2 model using the pre-trained checkpoints. The model is then wrapped in a predictor class, which simplifies the process of setting images, encoding prompts, and decoding masks.

Configure hyperparameters

We configure several hyperparameters to ensure the model learns effectively, such as the learning rate, weight decay, and gradient accumulation steps. These hyperparameters control the learning process, including how fast the model updates its weights and how it avoids overfitting. Feel free to play around with these.

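A configuration sketch: only the mask decoder and prompt encoder are put in training mode, with an AdamW optimizer and a step scheduler. The specific values are illustrative starting points, not tuned results.

import torch

# Train only the lightweight heads; the image encoder stays frozen
# (set_image runs without gradients, so it receives no updates).
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

optimizer = torch.optim.AdamW(predictor.model.parameters(),
                              lr=1e-4, weight_decay=4e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.2)

NO_OF_STEPS = 3000       # total training steps (assumption)
ACCUMULATION_STEPS = 4   # update weights every 4 steps (assumption)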

The optimizer is responsible for updating the model weights, while the scheduler adjusts the learning rate during training to improve convergence. By fine-tuning these parameters, we can achieve better segmentation accuracy.

Start training

The actual fine-tuning process is iterative, where in each step, a batch of images and masks for lungs only is passed through the model, and the loss is computed and used to update the model weights.

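Here is a condensed sketch of such a loop, modeled on community fine-tuning scripts for SAM 2. The predictor internals it touches (_prep_prompts, _features, _transforms, _orig_hw) are version-specific and may change between releases of the codebase.

import numpy as np
import torch

optimizer.zero_grad()
for step in range(1, NO_OF_STEPS + 1):
    # Pick a random training sample and prepare image, mask, and prompt points.
    entry = train_data[np.random.randint(len(train_data))]
    image, ann = read_image(entry)
    if ann.max() == 0:
        continue  # skip slices with no lung mask
    gt_mask, input_points = get_points(ann, num_points=1)

    predictor.set_image(image)  # frozen image encoder forward pass

    # Encode the point prompts with gradients enabled.
    _, unnorm_coords, labels, _ = predictor._prep_prompts(
        input_points, np.ones((input_points.shape[0], 1)),
        box=None, mask_logits=None, normalize_coords=True)
    sparse_emb, dense_emb = predictor.model.sam_prompt_encoder(
        points=(unnorm_coords, labels), boxes=None, masks=None)

    # Decode low-resolution masks and upscale them to the input resolution.
    high_res_feats = [f[-1].unsqueeze(0) for f in predictor._features["high_res_feats"]]
    low_res_masks, scores, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=True, repeat_image=False,
        high_res_features=high_res_feats)
    prd_masks = predictor._transforms.postprocess_masks(
        low_res_masks, predictor._orig_hw[-1])

    # Binary cross-entropy between the predicted and ground-truth masks.
    gt = torch.tensor(gt_mask, dtype=torch.float32, device="cuda")
    prd = torch.sigmoid(prd_masks[:, 0])
    loss = (-gt * torch.log(prd + 1e-5)
            - (1 - gt) * torch.log(1 - prd + 1e-5)).mean()

    loss.backward()
    if step % ACCUMULATION_STEPS == 0:  # gradient accumulation
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

torch.save(predictor.model.state_dict(), "fine_tuned_sam2.pt")  # file name is ours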

During each iteration, the model processes a batch of images, computes the segmentation masks, and compares them with the ground truth to calculate the loss. This loss is then used to adjust the model weights, gradually improving the model's performance. After training for about 3000 epochs, we get an accuracy (IoU - Intersection over Union) of about 72%.
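
For reference, here is a small sketch of the IoU metric used for this evaluation (the 0.5 threshold on predicted probabilities is an assumption):

import numpy as np

def iou(pred_mask, gt_mask, thresh=0.5):
    pred, gt = pred_mask > thresh, gt_mask > 0
    union = np.logical_or(pred, gt).sum()
    return np.logical_and(pred, gt).sum() / union if union else 0.0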

Step 6: Inference with the fine-tuned model

The model can then be used for inference, where it predicts segmentation masks on new, unseen images. Start with the read_image and get_points helper functions (defined during data preparation) to load the inference image and its mask along with key points.


Then load a sample image for inference along with the newly fine-tuned weights, and run inference under torch.no_grad().

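A hedged inference sketch: the weights file name matches the one saved in the training sketch above, and the sample selection is illustrative.

import numpy as np
import torch

# Load the fine-tuned weights saved at the end of training.
predictor.model.load_state_dict(torch.load("fine_tuned_sam2.pt"))
predictor.model.eval()

image, ann = read_image(test_data[0])           # an unseen test sample
_, input_points = get_points(ann, num_points=5)
pts = input_points.reshape(-1, 2)               # (N, 2) point prompts

with torch.no_grad():
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=pts, point_labels=np.ones(pts.shape[0]))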

In this step, we use the fine-tuned model to generate segmentation masks for test images. The predicted masks are then visualized alongside the original images and ground truth masks to evaluate the model's performance.


Final segmentation image on test data 

Conclusion

Fine-tuning SAM 2 offers a practical way to enhance its capabilities for specific tasks. Whether you're working on medical imaging, autonomous vehicles, or video editing, fine-tuning allows you to tailor SAM 2 to your unique needs. By following this guide, you can adapt SAM 2 for your projects and achieve state-of-the-art segmentation results.

For more advanced use cases, consider fine-tuning additional components of SAM 2, such as the image encoder. While this requires more resources, it offers greater flexibility and potentially larger performance improvements.
