The PyTorch team rewrote Meta's 'Segment Anything' model, making it 8 times faster than the original implementation

How can Meta's "Segment Anything" model be optimized? This blog post by the PyTorch team answers the question, working from simple techniques up to more advanced ones.

Generative AI has developed rapidly since the beginning of this year. Yet we often face a difficult problem: how to speed up the training and inference of generative AI models, especially when using PyTorch.

In this article, researchers from the PyTorch team provide us with a solution. The article focuses on how to use pure native PyTorch to accelerate generative AI models. It also introduces new PyTorch features and practical examples of how to combine them.

What was the result? The PyTorch team said they rewrote Meta's "Segment Anything" (SAM) model, producing code that is 8 times faster than the original implementation without losing accuracy, with every optimization done in native PyTorch.


Blog address: https://pytorch.org/blog/accelerating-generative-ai/

By the end of this article, you will know about:

  • torch.compile: PyTorch's model compiler. PyTorch 2.0 added the torch.compile() function, which can accelerate an existing model with a single line of code;
  • GPU quantization: accelerates the model by reducing the numerical precision of its computations;
  • SDPA (Scaled Dot Product Attention): a memory-efficient attention implementation;
  • Semi-structured (2:4) sparsity: a sparse memory format optimized for GPUs;
  • Nested Tensor: packs {tensor, mask} together so that non-uniformly sized data, such as images of different sizes, can be batched into a single tensor;
  • Triton custom operations: write GPU operations with the Triton Python DSL and integrate them easily into various components of PyTorch through custom operator registration.


[Figure: increased throughput and reduced memory overhead brought by PyTorch's native features]

SAM was proposed by Meta. For more information about this research, see the article "CV no longer exists? Meta releases 'Segment Anything' AI model; CV may usher in its GPT-3 moment".


Next, the article walks through the SAM optimization process: performance analysis, bottleneck identification, and how new PyTorch features (torch.compile, SDPA, Triton kernels, Nested Tensor, and semi-structured sparsity) can be applied to solve the problems SAM faces.

The article proceeds step by step, from the basics to more in-depth material. At the end, it presents the fast version of SAM, which interested readers can download from GitHub. The article also uses the Perfetto UI to visualize the profiling data and illustrate the value of each PyTorch feature.

GitHub address: https://github.com/pytorch-labs/segment-anything-fast

Rewriting the Segment Anything Model (SAM)

The SAM baseline used in this study runs in the float32 dtype with a batch size of 1, and its kernel trace was inspected with the PyTorch Profiler.
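A minimal sketch of how such a kernel trace can be captured, assuming a small stand-in model (the two-layer MLP is hypothetical, not SAM) and a hypothetical output file name. It uses torch.profiler to print a per-operator summary and to export a Chrome-format JSON trace, which the Perfetto UI can open.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

# Hypothetical stand-in for SAM's image encoder: a small float32 MLP, batch size 1.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256)).to(device).eval()
x = torch.randn(1, 64, 256, device=device)

# Record CPU-side operator calls, plus GPU kernels when a GPU is present.
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with torch.inference_mode():
    model(x)  # warm-up run so one-time setup costs do not dominate the trace
    with profile(activities=activities, record_shapes=True) as prof:
        model(x)

# Per-operator summary (aten::* rows) and a Chrome-format JSON trace for Perfetto.
summary = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
print(summary)
trace_path = os.path.join(tempfile.gettempdir(), "sam_baseline_trace.json")  # hypothetical name
prof.export_chrome_trace(trace_path)
```

Sorting by self CPU time surfaces operators such as aten::index that keep the host busy; on a GPU, sorting by a CUDA time column instead highlights the kernels themselves.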


From this trace, the authors identified two places in SAM that could be optimized:

The first is the long calls to aten::index, the underlying call produced by tensor indexing operations (such as []). The GPU time actually spent in aten::index is relatively low; the problem is that aten::index launches two kernels with a blocking cudaStreamSynchronize between them. The CPU therefore waits for the GPU to finish processing before it can launch the second kernel. To optimize SAM, the authors concluded that they should eliminate the blocking GPU synchronizations that cause this idle time.
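One way to check whether an operation forces such a synchronization is CUDA's sync debug mode. The sketch below is an illustration under assumptions, not the authors' tooling: `count_sync_warnings` is a hypothetical helper built on torch.cuda.set_sync_debug_mode, and which of the three indexing calls actually warns depends on your PyTorch version and hardware. A boolean mask is the most likely to synchronize, because the host has to learn how many elements are selected before it can size the output.

```python
import warnings

import torch


def count_sync_warnings(fn):
    """Run fn() with CUDA sync debug mode set to "warn" and return how many warnings it emitted.

    Hypothetical helper. On a CPU-only machine there is no GPU to wait for, so it returns 0.
    """
    if not torch.cuda.is_available():
        fn()
        return 0
    torch.cuda.set_sync_debug_mode("warn")
    try:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            fn()
    finally:
        torch.cuda.set_sync_debug_mode("default")
    return len(caught)


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, device=device)
cpu_idx = torch.arange(0, 1024, 2)                 # index tensor allocated on the CPU
dev_idx = torch.arange(0, 1024, 2, device=device)  # same indices, already on x's device

n_cpu_idx = count_sync_warnings(lambda: x[cpu_idx])  # may need a host-to-device copy of the index
n_dev_idx = count_sync_warnings(lambda: x[dev_idx])
n_mask = count_sync_warnings(lambda: x[x > 0])       # output size depends on the data
print(n_cpu_idx, n_dev_idx, n_mask)
```

Setting the mode to "error" instead of "warn" turns every synchronizing call into an exception, which makes it easy to pinpoint the offending line.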

The second is that SAM spends a lot of GPU time in matrix multiplications (dark green in the kernel trace), which is common in Transformers. Reducing the GPU time SAM spends on matrix multiplications would therefore speed it up significantly.

Next, this article uses SAM’s throughput (img/s) and memory overhead (GiB) to establish a baseline. After that comes the optimization process.


Bfloat16 half precision (plus GPU sync and batch processing)

To address the second problem above, making matrix multiplication take less time, the authors turned to bfloat16. Bfloat16 is a commonly used half-precision type that saves a lot of compute time and memory by reducing the precision of each parameter and activation.
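A minimal sketch of the dtype switch, assuming a toy MLP in place of SAM: converting the parameters to bfloat16 halves their storage, and the outputs stay close to the float32 reference. On a GPU the same `.to(torch.bfloat16)` call applies; the error tolerance used here is an illustrative choice, not a measured SAM result.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical toy model standing in for SAM; weights are created in float32.
torch.manual_seed(0)
model_fp32 = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256)).eval()
model_bf16 = copy.deepcopy(model_fp32).to(torch.bfloat16)  # cast every parameter to bfloat16


def param_bytes(module):
    """Total bytes used by a module's parameters."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


x = torch.randn(8, 256)
with torch.inference_mode():
    ref = model_fp32(x)
    out = model_bf16(x.to(torch.bfloat16)).float()  # inputs must match the parameter dtype

max_abs_err = (out - ref).abs().max().item()
print(param_bytes(model_fp32), param_bytes(model_bf16), max_abs_err)
```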


Replacing the default float32 dtype with bfloat16

For removing GPU synchronization, the authors found two places that could be optimized.


Specifically (the variable names below all appear in the code), the authors found that SAM's image encoder contains the variables q_coords and k_coords, which act as coordinate scalers. These variables are allocated and processed on the CPU. However, once they are used to index into rel_pos_resized, the indexing operation automatically moves them to the GPU, and this copy causes a GPU synchronization. The authors noted that this can be solved by rewriting the code with torch.where.
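The source describes this fix only at a high level. The sketch below illustrates the two patterns under assumptions: the function and argument names are hypothetical (loosely modeled on a point-prompt embedding and on a relative-position index), and it is not the code from segment-anything-fast. The first pair replaces boolean-mask updates with shape-preserving torch.where calls; `relative_coords` shows the related habit of allocating index tensors directly on the target device so that indexing never needs a host-to-device copy.

```python
import torch


def embed_points_masked(emb, labels, not_a_point, neg_embed, pos_embed):
    """Boolean-mask version: every emb[labels == k] is a data-dependent index operation."""
    out = emb.clone()
    out[labels == -1] = 0.0
    out[labels == -1] += not_a_point
    out[labels == 0] += neg_embed
    out[labels == 1] += pos_embed
    return out


def embed_points_where(emb, labels, not_a_point, neg_embed, pos_embed):
    """torch.where version: fixed output shape, nothing depends on how many labels match."""
    lab = labels.unsqueeze(-1)  # (B, N, 1) so the condition broadcasts over the embedding dim
    out = torch.where(lab == -1, not_a_point.expand_as(emb), emb)
    out = torch.where(lab == 0, out + neg_embed, out)
    out = torch.where(lab == 1, out + pos_embed, out)
    return out


def relative_coords(q_size, k_size, device):
    """Relative-position indices allocated directly on `device`, avoiding a host-to-device copy."""
    q = torch.arange(q_size, device=device)[:, None] * max(k_size / q_size, 1.0)
    k = torch.arange(k_size, device=device)[None, :] * max(q_size / k_size, 1.0)
    return ((q - k) + (k_size - 1) * max(q_size / k_size, 1.0)).long()


torch.manual_seed(0)
B, N, D = 2, 5, 8
emb = torch.randn(B, N, D)
labels = torch.tensor([[-1, 0, 1, 1, -1], [0, 0, 1, -1, 1]])
not_a_point, neg_embed, pos_embed = torch.randn(3, D)

masked = embed_points_masked(emb, labels, not_a_point, neg_embed, pos_embed)
via_where = embed_points_where(emb, labels, not_a_point, neg_embed, pos_embed)
rel = relative_coords(4, 4, device=emb.device)
```

Both embedding functions compute the same values; the torch.where form simply evaluates every branch for every element, trading a little redundant arithmetic for a graph with no data-dependent shapes.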

Kernel Trace

After applying these changes, the authors noticed significant gaps between individual kernel calls, especially at small batch sizes (here, 1). To understand this phenomenon better, they profiled SAM inference with a batch size of 8.


Looking at the time spent per kernel, the authors observed that most of SAM's GPU time is spent on elementwise kernels and softmax operations.

Now you can see that the relative cost of matrix multiplication is much smaller.


Combining the removal of GPU synchronization with bfloat16, SAM's performance improves by 3 times.


torch.compile (graph breaks and CUDA graphs)

While studying SAM in depth, the authors found many small operations. Because fusing operations with a compiler brings large benefits, they applied torch.compile, which performs optimizations such as:

  • Fusing sequences of operations such as nn.LayerNorm or nn.GELU into a single GPU kernel;
  • Fusing operations that immediately follow a matrix multiplication kernel, to reduce the number of GPU kernel calls.

These optimizations reduce the number of GPU global memory round trips, which speeds up inference. We can now try torch.compile on SAM's image encoder. To maximize performance, the authors also used some advanced compilation techniques.


Kernel Tracing


The results show that torch.compile works very well.


Softmax now takes up a large share of the time, followed by various GEMM variants. The measurements for this step cover batch sizes of 8 and above.


SDPA: scaled_dot_product_attention

Next, the authors experimented with SDPA (scaled_dot_product_attention), focusing on the attention mechanism. In general, naive attention scales quadratically with sequence length in both time and memory. PyTorch's SDPA operation builds on the memory-efficient attention principles of FlashAttention, FlashAttentionV2, and xFormers, and can significantly speed up attention on the GPU. Combined with torch.compile, this operation allows a common pattern found in MultiheadAttention variants to be expressed and fused. After a small change, the model can now use scaled_dot_product_attention.
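The sketch below, on toy shapes, checks F.scaled_dot_product_attention against a naive implementation that materializes the full L x L score matrix. Passing a floating-point attn_mask adds it to the scores before the softmax, which is one way an additive bias such as a relative-position term can be carried into SDPA; whether SAM's rewrite used exactly this form is an assumption here.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v, bias=None):
    """Reference attention that materializes the full (L x L) score matrix."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if bias is not None:
        scores = scores + bias  # additive bias, e.g. a relative-position term
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
B, H, L, D = 2, 4, 64, 32  # batch, heads, sequence length, head dim (toy sizes)
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
bias = torch.randn(B, H, L, L)

ref = naive_attention(q, k, v, bias)
# A floating-point attn_mask is added to the scaled scores before the softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```

Without an explicit `scale` argument, SDPA uses 1/sqrt(D), matching the naive version; the fused kernels avoid keeping the intermediate score matrix around, which is where the memory savings come from.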


Kernel Tracing

The kernel trace now shows the memory-efficient attention kernel taking up a significant share of GPU compute time.


Using PyTorch's native scaled_dot_product_attention allows the batch size to be increased significantly. The measurements for this step cover batch sizes of 32 and above.


After that, the authors also experimented with Triton, NestedTensor, batched predict_torch, int8 quantization, semi-structured (2:4) sparsity, and other techniques.

For example, the authors used a custom positional Triton kernel and measured the results at a batch size of 32.
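The source does not show the authors' Triton kernel. As a hedged illustration of the technique only, the sketch below fuses an additive positional bias with a row-wise softmax in a single Triton kernel (the names `bias_softmax` and `_bias_softmax_kernel` are hypothetical). It follows the structure of Triton's fused-softmax tutorial, assumes each row fits in one block and float32 inputs, and falls back to plain PyTorch when Triton or a CUDA GPU is unavailable.

```python
import torch

try:
    import triton
    import triton.language as tl

    USE_TRITON = torch.cuda.is_available()  # the kernel below only targets CUDA GPUs
except ImportError:
    USE_TRITON = False

if USE_TRITON:

    @triton.jit
    def _bias_softmax_kernel(out_ptr, x_ptr, bias_ptr, n_cols, BLOCK: tl.constexpr):
        # One program instance handles one row of the (rows x n_cols) score matrix.
        row = tl.program_id(0)
        offs = tl.arange(0, BLOCK)
        mask = offs < n_cols
        base = row * n_cols
        x = tl.load(x_ptr + base + offs, mask=mask, other=-float("inf"))
        b = tl.load(bias_ptr + base + offs, mask=mask, other=0.0)
        z = x + b
        z = z - tl.max(z, axis=0)  # subtract the row max for numerical stability
        num = tl.exp(z)            # padded lanes give exp(-inf) = 0
        tl.store(out_ptr + base + offs, num / tl.sum(num, axis=0), mask=mask)


def bias_softmax(x, bias):
    """softmax(x + bias) over the last dim: one fused Triton kernel on CUDA, plain PyTorch otherwise."""
    if not (USE_TRITON and x.is_cuda and x.dtype == torch.float32):
        return torch.softmax(x + bias, dim=-1)
    n_cols = x.shape[-1]
    x2 = x.contiguous().view(-1, n_cols)
    b2 = bias.expand_as(x).contiguous().view(-1, n_cols)  # materialize the broadcast bias
    out = torch.empty_like(x2)
    _bias_softmax_kernel[(x2.shape[0],)](out, x2, b2, n_cols, BLOCK=triton.next_power_of_2(n_cols))
    return out.view_as(x)


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
scores = torch.randn(2, 4, 16, 16, device=device)  # (batch, heads, queries, keys)
pos_bias = torch.randn(4, 16, 16, device=device)    # hypothetical positional bias
probs = bias_softmax(scores, pos_bias)
```

Fusing the bias add into the softmax means the biased scores are never written back to global memory, which is the kind of round trip a custom kernel is meant to remove.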


After adding Nested Tensor, the authors report measurements for batch sizes of 32 and above.
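A minimal sketch of what Nested Tensor replaces, assuming hypothetical box prompts (two images with five and three boxes). torch.nested.nested_tensor holds the ragged batch as one tensor, and torch.nested.to_padded_tensor recovers the dense {tensor, mask} view that would otherwise have to be managed by hand. Nested tensors are still a prototype feature, so operator coverage varies by PyTorch version.

```python
import torch

torch.manual_seed(0)
# Hypothetical prompts: image 0 has 5 boxes, image 1 has 3 (4 coordinates per box).
boxes = [torch.randn(5, 4), torch.randn(3, 4)]

nt = torch.nested.nested_tensor(boxes)           # one tensor whose first dim is ragged
padded = torch.nested.to_padded_tensor(nt, 0.0)  # dense (2, 5, 4), zero-filled

# The {tensor, mask} pair that the nested tensor stands in for.
lengths = torch.tensor([t.shape[0] for t in nt.unbind()])
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]  # (2, 5) bool
```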


After adding int8 quantization, the authors again report measurements for batch sizes of 32 and above.
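The source does not spell out the quantization scheme. The sketch below is one common formulation of int8 dynamic quantization, written in plain PyTorch for clarity rather than speed: weights get one symmetric scale per output channel, activations get one scale per row computed at run time, and the integer products are rescaled back to floating point. The helper names are hypothetical; fast implementations use dedicated int8 matmul kernels instead of the float emulation here.

```python
import torch


def quantize_per_row(t):
    """Symmetric int8 quantization with one scale per row (reduces over the last dim)."""
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale


def int8_dynamic_linear(x, weight, bias):
    """Hypothetical int8 dynamic-quantized linear layer approximating x @ weight.T + bias."""
    w_q, w_scale = quantize_per_row(weight)  # (out, in) int8, (out, 1) scales
    x_q, x_scale = quantize_per_row(x)       # (..., in) int8, (..., 1) scales, computed per call
    # Emulate int8 x int8 -> int32 accumulation in float32. Every partial sum here is an
    # integer below 2**24 (127 * 127 * 256 < 2**24), so float32 represents it exactly.
    acc = x_q.float() @ w_q.float().t()
    return acc * x_scale * w_scale.t() + bias


torch.manual_seed(0)
layer = torch.nn.Linear(256, 128)
x = torch.randn(4, 32, 256)
with torch.no_grad():
    ref = layer(x)
    out = int8_dynamic_linear(x, layer.weight, layer.bias)
    w_q, _ = quantize_per_row(layer.weight)

rel_err = ((out - ref).abs().max() / ref.abs().max()).item()
print(f"int8 weight bytes: {w_q.numel()}, fp32 weight bytes: {4 * w_q.numel()}, rel_err={rel_err:.4f}")
```

Storing the weights as int8 cuts their memory by 4x relative to float32; the per-row scales keep one outlier row from destroying the precision of all the others.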


The final step is semi-structured sparsity. The study shows that matrix multiplication is still a bottleneck. The solution is to use sparsification to approximate matrix multiplication: by sparsifying matrices (i.e., zeroing out values), fewer bits can be used to store weight and activation tensors. The process of choosing which weights in a tensor to set to zero is called pruning. Pruning away smaller weights can potentially reduce model size without a significant loss of accuracy.

There are many pruning methods, ranging from completely unstructured to highly structured. Unstructured pruning theoretically has minimal impact on accuracy, but GPUs, while very efficient at large dense matrix multiplications, may suffer significant performance degradation in sparse cases. Semi-structured (or 2:4) sparsity, a pruning method recently supported by PyTorch, aims to strike a balance between the two. This sparse storage reduces the original tensor by 50% while still producing a dense tensor output.
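To make the storage format concrete, here is a plain-PyTorch sketch under stated assumptions: each group of four values along a row keeps at most two nonzeros, so the matrix can be stored as half the values plus the position of each kept value inside its group. The helper names are hypothetical, and the positions are kept in uint8 for readability; the real GPU format packs this metadata into 2-bit indices with a hardware-specific layout.

```python
import torch


def compress_2to4(dense):
    """Split a 2:4-sparse matrix into its kept values and their in-group positions.

    Hypothetical helper: assumes every group of 4 along a row has at most 2 nonzeros.
    """
    rows, cols = dense.shape
    groups = dense.reshape(rows, cols // 4, 4)
    # The two largest-magnitude positions per group cover all nonzeros; sorting keeps order.
    idx = groups.abs().topk(2, dim=-1).indices.sort(dim=-1).values
    values = groups.gather(-1, idx)
    return values.reshape(rows, cols // 2), idx.to(torch.uint8)


def decompress_2to4(values, idx, cols):
    """Rebuild the dense matrix by scattering the kept values back into groups of 4."""
    rows = values.shape[0]
    groups = torch.zeros(rows, cols // 4, 4, dtype=values.dtype, device=values.device)
    groups.scatter_(-1, idx.long(), values.reshape(rows, cols // 4, 2))
    return groups.reshape(rows, cols)


# Build a 2:4-sparse example by keeping the 2 largest-magnitude entries of each group.
torch.manual_seed(0)
w = torch.randn(8, 16)
g = w.reshape(8, 4, 4)
keep = g.abs().topk(2, dim=-1).indices
w_24 = torch.zeros_like(g).scatter_(-1, keep, g.gather(-1, keep)).reshape(8, 16)

values, idx = compress_2to4(w_24)
restored = decompress_2to4(values, idx, cols=16)
print(values.numel(), w_24.numel())  # half as many stored values
```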


To use this sparse storage format and its associated fast kernels, the next step is to prune the weights. At 2:4 sparsity, the authors prune the two smallest-magnitude weights in each block of four. Converting the weights from PyTorch's default ("strided") layout to the new semi-structured sparse layout is easy: the authors' implementation of apply_sparse(model) takes only 32 lines of Python code.
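The authors' 32-line apply_sparse is not reproduced in the source, so the following is a hypothetical sketch of the same idea: magnitude-prune every nn.Linear weight to the 2:4 pattern, then, only on GPUs with compute capability 8.0 or newer and half-precision weights, swap the weight for the semi-structured layout via torch.sparse.to_sparse_semi_structured. That conversion also imposes minimum shape constraints that this sketch does not check.

```python
import torch
import torch.nn as nn


def prune_2to4_(weight):
    """In place, zero the two smallest-magnitude entries in every group of 4 along each row."""
    out_f, in_f = weight.shape
    groups = weight.reshape(out_f, in_f // 4, 4)
    drop = groups.abs().topk(2, dim=-1, largest=False).indices
    mask = torch.ones_like(groups).scatter_(-1, drop, 0.0)
    weight.mul_(mask.reshape(out_f, in_f))
    return weight


def apply_sparse(model):
    """Hypothetical helper: 2:4-prune every eligible nn.Linear, converting the layout where supported."""
    for module in model.modules():
        if not isinstance(module, nn.Linear) or module.in_features % 4 != 0:
            continue
        with torch.no_grad():
            prune_2to4_(module.weight)
        w = module.weight
        # Semi-structured kernels need a sparse-tensor-core GPU (compute capability >= 8.0)
        # and fp16/bf16 weights; the conversion also has minimum-shape requirements.
        if (w.is_cuda and w.dtype in (torch.float16, torch.bfloat16)
                and torch.cuda.get_device_capability(w.device)[0] >= 8):
            from torch.sparse import to_sparse_semi_structured
            module.weight = nn.Parameter(to_sparse_semi_structured(w))
    return model


torch.manual_seed(0)
model = apply_sparse(nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)))
y = model(torch.randn(2, 16))
```

On a CPU the weights stay in the strided layout with zeros in place, so the model still runs but without any speedup; the gain only appears once the semi-structured kernels are used.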


At 2:4 sparsity, the authors observed peak performance for SAM with vit_b at a batch size of 32.


To sum up in one sentence: this article presents the fastest Segment Anything implementation in PyTorch so far, in which the original SAM is rewritten in pure PyTorch, using a series of newly released official features, without losing accuracy.

Interested readers can check out the original blog for more information.

Reference link: https://pytorch.org/blog/accelerating-generative-ai/


This article is reproduced from: 机器之心.