Fine-Tuning SAM 2 on a Custom Dataset: Tutorial
Meta's Segment Anything Model 2 (SAM 2) is the latest innovation in segmentation technology. It is Meta’s first unified model that can segment objects in both images and videos in real time.
But why fine-tune SAM 2 if it can already segment anything?
While SAM 2 is powerful out-of-the-box, its performance on rare or domain-specific tasks may not always meet expectations. Fine-tuning allows you to adapt SAM2 to your specific needs, improving its accuracy and efficiency for your particular use case.
In this article, I’ll guide you step-by-step through the fine-tuning process of SAM 2.
SAM2 is a foundation model developed by Meta for promptable visual segmentation in images and videos. Unlike its predecessor, SAM, which primarily focused on static images, SAM2 is designed to handle the complexities of video segmentation as well.
SAM2 - Task, Model, and Data (Source: Ravi et al., 2024)
It employs a transformer architecture with streaming memory, enabling real-time video processing. SAM 2's training involved a vast and varied dataset featuring the novel SA-V dataset, which includes more than 600,000 masklet annotations spanning 51,000 videos.
Its data engine, which allows for interactive data collection and model improvement, is what gives the model its ability to segment such a wide variety of objects. This engine enables SAM 2 to continuously learn and adapt, making it more efficient at handling new and challenging data. However, for domain-specific tasks or rare objects, fine-tuning is essential to achieve optimal performance.
In the context of SAM 2, fine-tuning is the process of further training the pre-trained SAM 2 model on a specific dataset to enhance its performance for a particular task or domain. While SAM 2 is a powerful tool trained on a broad and diverse dataset, its general-purpose nature may not always yield optimal results for specialized or rare tasks.
For example, if you're working on a medical imaging project that requires the identification of specific tumor types, the model's performance might fall short due to its generalized training.
The fine-tuning process
Fine-tuning SAM 2 addresses this limitation by allowing you to adapt the model to your specific dataset. This process improves the model's accuracy and makes it more effective for your unique use case.
The key benefits of fine-tuning SAM 2 are higher segmentation accuracy on your target domain, more reliable handling of rare or specialized objects, and a model that is more efficient and effective for your particular use case.
To get started with fine-tuning SAM 2, you'll need a few prerequisites in place: a Python environment with PyTorch and the SAM 2 library installed, access to a GPU (this tutorial runs in Google Colab), a Kaggle account with an API token for downloading the dataset, and an annotated segmentation dataset with images and masks.
The quality of your dataset is crucial for fine-tuning the SAM 2 model. High-quality annotated images or videos with accurate segmentation masks are essential to achieving optimal performance. Precise annotations enable the model to learn the correct features, leading to better segmentation accuracy and robustness in real-world applications.
The first step involves acquiring the dataset, which forms the backbone of the training process. We sourced our data from Kaggle, a reliable platform that provides a diverse range of datasets. Using the Kaggle API, we downloaded the data in the required format, ensuring that the images and corresponding segmentation masks were readily available for further processing.
After downloading the dataset, we performed a few preparation steps. Since the SAM 2 model requires input in specific formats, we converted the images and their corresponding segmentation masks accordingly. The processed images and masks were then organized into respective folders for easy access during the fine-tuning process, and the paths to these files were written into a CSV file (train.csv) to facilitate data loading during training. The final step involved validating the dataset, for example by confirming that every image has a corresponding mask, to ensure its accuracy.
Here is a quick notebook to take you through the code for dataset creation. You can either follow this dataset-creation path or directly use any dataset available online, as long as it is in the same format as the one mentioned in the prerequisites.
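As a rough sketch of that dataset-creation step, the snippet below pairs image and mask filenames and writes them to train.csv. The folder names and column names are assumptions and should be adjusted to match your data.

import os
import pandas as pd

# Assumed layout: one folder of images and one of masks with matching filenames.
image_dir = "dataset/images"
mask_dir = "dataset/masks"

rows = []
for name in sorted(os.listdir(image_dir)):
    mask_path = os.path.join(mask_dir, name)
    if os.path.exists(mask_path):  # keep only images that have a mask
        rows.append({"images": name, "masks": name})

pd.DataFrame(rows).to_csv("dataset/train.csv", index=False)
print(f"Wrote {len(rows)} image/mask pairs to train.csv")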
Segment Anything Model 2 contains several components, but the key to faster fine-tuning is to train only the lightweight components, such as the mask decoder and prompt encoder, rather than the entire model. The steps for fine-tuning this model are as follows:
To start the fine-tuning process, we need to install the SAM-2 library, which is crucial for the Segment Anything Model (SAM2). This model is designed to handle various segmentation tasks effectively. The installation involves cloning the SAM-2 repository from GitHub and installing the necessary dependencies.
!git clone https://github.com/facebookresearch/segment-anything-2
%cd /content/segment-anything-2
!pip install -q -e .
This code snippet ensures the SAM2 library is correctly installed and ready for use in our fine-tuning workflow.
Once the SAM-2 library is installed, the next step is to acquire the dataset we’ll be using for fine-tuning. We use a dataset available on Kaggle, specifically a chest CT segmentation dataset containing images and masks of lungs, heart, and trachea.
The dataset contains chest CT images together with corresponding segmentation masks for the lungs, heart, and trachea.
Image from the CT Scan Dataset
In this blog, we’ll use only images and masks of lungs for segmentation. The Kaggle API allows us to download datasets directly to our environment. We start by uploading the kaggle.json file from Kaggle to access any dataset easily.
To get kaggle.json, go to the Settings tab under your user profile and select Create New Token. This will trigger the download of kaggle.json, a file containing your API credentials.
# get dataset from Kaggle
from google.colab import files
files.upload()  # This will prompt you to upload the kaggle.json file
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d polomarco/chest-ct-segmentation
Unzip the data:
!unzip chest-ct-segmentation.zip -d chest-ct-segmentation
With the dataset ready, let’s start the fine-tuning process. As I previously mentioned, the key here is to fine-tune only the lightweight components of SAM2, such as the mask decoder and prompt encoder, rather than the entire model. This approach is more efficient and requires fewer resources.
For the fine-tuning process, we need to start with pre-trained SAM2 model weights. These weights, called "checkpoints," are the starting point for further training. The checkpoints have been trained on a wide range of images, and by fine-tuning them on our specific dataset, we can achieve better performance on our target tasks.
!wget -O sam2_hiera_tiny.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
!wget -O sam2_hiera_small.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
!wget -O sam2_hiera_base_plus.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
!wget -O sam2_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
In this step, we download various SAM-2 checkpoints that correspond to different model sizes (e.g., tiny, small, base_plus, large). The choice of checkpoint can be adjusted based on the computational resources available and the specific task at hand.
With the dataset downloaded, the next step is to prepare it for training. This involves splitting the dataset into training and testing sets and creating data structures that can be fed into the SAM 2 model during fine-tuning.
We split the dataset into a training set (80%) and a testing set (20%) to ensure that we can evaluate the model's performance after training. The training data will be used to fine-tune the SAM 2 model, while the testing data will be used for inference and evaluation.
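A minimal sketch of this split, assuming the dataset ships a train.csv listing image and mask filenames (the column names below are assumptions, adjust them to the actual CSV):

import pandas as pd
from sklearn.model_selection import train_test_split

data_dir = "chest-ct-segmentation"
df = pd.read_csv(f"{data_dir}/train.csv")

# 80% of the rows for fine-tuning, 20% held out for inference and evaluation.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Store (image path, mask path) pairs that the data-loading code can iterate over.
train_data = [
    {"image": f"{data_dir}/images/{row['images']}",
     "annotation": f"{data_dir}/masks/{row['masks']}"}
    for _, row in train_df.iterrows()
]
test_data = [
    {"image": f"{data_dir}/images/{row['images']}",
     "annotation": f"{data_dir}/masks/{row['masks']}"}
    for _, row in test_df.iterrows()
]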
After splitting your dataset into training and testing sets, the next step involves creating binary masks, selecting key points within these masks, and visualizing these elements to ensure the data is correctly processed.
1. Reading and resizing images: The process starts by randomly selecting an image and its corresponding mask from the dataset. The image is converted from BGR to RGB format, which is the standard color format for most deep learning models. The corresponding annotation (mask) is read in grayscale mode. Then, both the image and the annotation mask are resized to a maximum dimension of 1024 pixels, maintaining the aspect ratio to ensure that the data fits within the model's input requirements and reduces computational load.
2. Binarization of segmentation masks: The multi-class annotation mask (which might have multiple object classes labeled with different pixel values) is converted into a binary mask. This mask highlights all the regions of interest in the image, simplifying the segmentation task to a binary classification problem: object vs. background. The binary mask is then eroded using a 5x5 kernel.
Erosion slightly reduces the mask's size, which helps avoid boundary effects when selecting points. This ensures the selected points are well within the object's interior rather than near its edges, which might be noisy or ambiguous.
Key points are selected from within the eroded mask. These points act as prompts during the fine-tuning process, guiding the model on where to focus its attention. The points are selected randomly from the interior of the objects to ensure they are representative and not influenced by noisy boundaries.
3. Visualization: This step is crucial for verifying that the data preprocessing steps have been executed correctly. By visually inspecting the points on the binarized mask, you can ensure that the model will receive appropriate input during training. Finally, the binary mask is reshaped and formatted correctly (with dimensions suitable for the model input), and the points are also reshaped for further use in the training process. The function returns the processed image, binary mask, selected points, and the number of masks found.
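The helper below is a minimal sketch of such a preprocessing function, assuming the train_data list of image/annotation paths built during the split above. It is not the exact implementation, but it follows the three steps just described.

import cv2
import numpy as np

def read_batch(data, max_dim=1024, num_points=10):
    # 1. Randomly pick an image/mask pair and read them (image in RGB, mask in grayscale).
    entry = data[np.random.randint(len(data))]
    image = cv2.cvtColor(cv2.imread(entry["image"]), cv2.COLOR_BGR2RGB)
    annotation = cv2.imread(entry["annotation"], cv2.IMREAD_GRAYSCALE)

    # Resize so the longest side is at most max_dim, preserving the aspect ratio.
    scale = min(1.0, max_dim / max(image.shape[:2]))
    image = cv2.resize(image, None, fx=scale, fy=scale)
    annotation = cv2.resize(annotation, None, fx=scale, fy=scale,
                            interpolation=cv2.INTER_NEAREST)

    # 2. Binarize the (possibly multi-class) mask and erode it with a 5x5 kernel
    #    so sampled points stay away from noisy boundaries.
    binary_mask = (annotation > 0).astype(np.uint8)
    eroded = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)

    # Sample random interior points to use as prompts during fine-tuning.
    ys, xs = np.where(eroded > 0)
    points = []
    if len(ys) > 0:
        idx = np.random.choice(len(ys), size=min(num_points, len(ys)), replace=False)
        points = [[[xs[i], ys[i]]] for i in idx]  # shape: (num_points, 1, 2)

    # 3. Reshape into the layout expected downstream: one binary mask per point prompt.
    masks = np.repeat(binary_mask[np.newaxis, ...], len(points), axis=0)
    return image, masks, np.array(points), len(masks)

A quick call such as image, masks, points, n = read_batch(train_data), followed by plotting with matplotlib, is enough to verify the preprocessing.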
Visualizing a sample this way produces the following figure, containing the original image from the dataset along with its binarized mask and the binarized mask with sampled points.
Original image, binarized mask, and binarized mask with points for the dataset.
Fine-tuning the SAM2 model involves several steps, including loading the model, setting up the optimizer and scheduler, and iteratively updating the model weights based on the training data.
Load the model checkpoints:
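A minimal sketch of this step, assuming the repository installed earlier and the small checkpoint downloaded above (the config filename is an assumption and must match the checkpoint you choose):

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

model_cfg = "sam2_hiera_s.yaml"      # config shipped with the repo for the small model
checkpoint = "sam2_hiera_small.pt"   # checkpoint downloaded in the previous step

sam2_model = build_sam2(model_cfg, checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Put only the lightweight components into training mode; the image encoder stays frozen.
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)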
We start by building the SAM2 model using the pre-trained checkpoints. The model is then wrapped in a predictor class, which simplifies the process of setting images, encoding prompts, and decoding masks.
We configure several hyperparameters to ensure the model learns effectively, such as the learning rate, weight decay, and gradient accumulation steps. These hyperparameters control the learning process, including how fast the model updates its weights and how it avoids overfitting. Feel free to play around with these.
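One way to set up the optimizer, scheduler, and mixed-precision scaler; the hyperparameter values below are assumed starting points, not the only valid choices:

import torch

learning_rate = 1e-5
weight_decay = 4e-5
accumulation_steps = 4  # number of batches to accumulate gradients over

# Optimize only the lightweight components (mask decoder and prompt encoder).
optimizer = torch.optim.AdamW(
    params=[
        {"params": predictor.model.sam_mask_decoder.parameters()},
        {"params": predictor.model.sam_prompt_encoder.parameters()},
    ],
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Lower the learning rate periodically to help convergence.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.2)

# Mixed precision keeps memory usage manageable on a single GPU.
scaler = torch.cuda.amp.GradScaler()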
The optimizer is responsible for updating the model weights, while the scheduler adjusts the learning rate during training to improve convergence. By fine-tuning these parameters, we can achieve better segmentation accuracy.
The actual fine-tuning process is iterative, where in each step, a batch of images and masks for lungs only is passed through the model, and the loss is computed and used to update the model weights.
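The loop below is a high-level sketch of one way to run these iterations. predict_masks_for_prompts is a hypothetical helper standing in for the forward pass through SAM 2's prompt encoder and mask decoder (replace it with your actual forward code); the loss and IoU terms follow the description in the surrounding text.

import numpy as np
import torch
import torch.nn.functional as F

NO_OF_STEPS = 3000  # assumed number of iterations
mean_iou = 0.0

for step in range(1, NO_OF_STEPS + 1):
    with torch.cuda.amp.autocast():
        image, masks, points, num_masks = read_batch(train_data)
        if num_masks == 0:
            continue

        # Hypothetical helper: encode the image, run the point prompts through the
        # prompt encoder, and decode per-prompt mask logits of shape (N, H, W).
        pred_logits = predict_masks_for_prompts(predictor, image, points)
        gt_masks = torch.tensor(masks, dtype=torch.float32, device=pred_logits.device)

        # Binary segmentation loss between predicted and ground-truth masks.
        pred_probs = torch.sigmoid(pred_logits)
        seg_loss = F.binary_cross_entropy(pred_probs, gt_masks)

        # Track IoU (the metric reported below) with a running average.
        pred_bin = (pred_probs > 0.5).float()
        inter = (pred_bin * gt_masks).sum()
        union = pred_bin.sum() + gt_masks.sum() - inter
        mean_iou = 0.99 * mean_iou + 0.01 * (inter / (union + 1e-6)).item()

    # Gradient accumulation: scale the loss and only step the optimizer every few batches.
    scaler.scale(seg_loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

    if step % 500 == 0:
        torch.save(predictor.model.state_dict(), "fine_tuned_sam2.pt")
        print(f"step {step}: loss={seg_loss.item():.4f}, mean IoU={mean_iou:.3f}")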
During each iteration, the model processes a batch of images, computes the segmentation masks, and compares them with the ground truth to calculate the loss. This loss is then used to adjust the model weights, gradually improving the model's performance. After training for about 3000 epochs, we reach an IoU (Intersection over Union) of about 72%.
The model can then be used for inference, where it predicts segmentation masks on new, unseen images. Start with the read_images and get_points helper functions to get the inference image and its mask along with key points.
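A sketch of what such helpers might look like; the names follow the text above, and the exact signatures are assumptions:

import cv2
import numpy as np

def read_images(image_path, mask_path, max_dim=1024):
    # Load the test image and its ground-truth mask, resizing as during training.
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    scale = min(1.0, max_dim / max(image.shape[:2]))
    image = cv2.resize(image, None, fx=scale, fy=scale)
    mask = cv2.resize(mask, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST)
    return image, (mask > 0).astype(np.uint8)

def get_points(mask, num_points=30):
    # Sample random point prompts from inside the mask region.
    ys, xs = np.where(mask > 0)
    idx = np.random.choice(len(ys), size=min(num_points, len(ys)), replace=False)
    return np.array([[[xs[i], ys[i]]] for i in idx])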
Then load the sample images you want to run inference on, along with the newly fine-tuned weights, and perform inference with torch.no_grad().
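A sketch of the inference step, assuming the fine-tuned weights were saved to a file such as fine_tuned_sam2.pt during training (the file name is an assumption):

import numpy as np
import torch

# Pick a test sample and build point prompts from its ground-truth mask.
sample = test_data[0]
image, mask = read_images(sample["image"], sample["annotation"])
input_points = get_points(mask, num_points=30)

# Load the fine-tuned weights into the predictor's model.
predictor.model.load_state_dict(torch.load("fine_tuned_sam2.pt"))

with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_points[0],   # one (1, 2) point prompt
        point_labels=np.array([1]),     # 1 = foreground point
    )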
In this step, we use the fine-tuned model to generate segmentation masks for test images. The predicted masks are then visualized alongside the original images and ground truth masks to evaluate the model's performance.
Final segmentation image on test data
Fine-tuning SAM 2 offers a practical way to enhance its capabilities for specific tasks. Whether you're working on medical imaging, autonomous vehicles, or video editing, fine-tuning lets you tailor SAM 2 to your unique needs. By following this guide, you can adapt SAM 2 to your own projects and achieve strong segmentation results.
For more advanced use cases, consider fine-tuning additional components of SAM2, such as the image encoder. While this requires more resources, it offers greater flexibility and performance improvements.