How should we optimize Meta's "Segment Anything" model? This blog post by the PyTorch team answers that question step by step, from the basics to the details.
Since the beginning of the year, generative AI has developed rapidly. Yet we often face a hard problem: how to speed up the training and inference of generative AI models, especially in PyTorch. In this article, researchers from the PyTorch team offer a solution. The article focuses on accelerating generative AI models with pure, native PyTorch. It also introduces new PyTorch features and gives practical examples of how to combine them. The result? The PyTorch team says it rewrote Meta's Segment Anything Model (SAM), producing code that is 8 times faster than the original implementation with no loss of accuracy, using only native PyTorch optimizations.
Blog address: https://pytorch.org/blog/accelerating-generative-ai/
By the end of this article, you will know about:
- torch.compile: The PyTorch model compiler. PyTorch 2.0 added a new function, torch.compile(), that can accelerate existing models with a single line of code;
- GPU quantization: Accelerates models by reducing numerical precision;
- SDPA (Scaled Dot Product Attention): A memory-efficient attention implementation;
- Semi-structured (2:4) sparsity: A sparse memory format optimized for GPUs;
- Nested Tensor: Packs {tensor, mask} together to batch non-uniformly sized data, such as images of different sizes, into a single tensor;
- Triton Custom Operations: Use the Triton Python DSL to write GPU operations and easily integrate them into various components of PyTorch through custom operator registration.
You will also see the increased throughput and reduced memory overhead that these native PyTorch features bring.
SAM was proposed by Meta. For more information about this research, please refer to "CV no longer exists? Meta releases 'Segment Anything' AI model, CV may usher in its GPT-3 moment". Next, the article walks through the SAM optimization process, including performance analysis, bottleneck identification, and how to apply new PyTorch features to solve the problems SAM faces. Along the way, it introduces several new PyTorch features: torch.compile, SDPA, Triton kernels, Nested Tensor, and semi-structured sparsity.
The article goes deeper step by step and ends with the fast version of SAM, which interested readers can download from GitHub. It also uses the Perfetto UI to visualize the profiling data and illustrate the value of each PyTorch feature.
GitHub address: https://github.com/pytorch-labs/segment-anything-fast
Rewriting SAM, the Segment Anything Model
The study uses a SAM baseline with the float32 dtype and a batch size of 1. Viewing the kernel trace with the PyTorch Profiler gives the following results:
The article found two places in SAM that can be optimized.
The first is the long calls to aten::index, the underlying call triggered by tensor indexing operations (such as []). The GPU itself spends relatively little time in aten::index. The problem is that aten::index launches two kernels, and a blocking cudaStreamSynchronize happens between them. This means the CPU waits for the GPU to finish processing before it launches the second kernel. Therefore, to optimize SAM, the article argues that one should eliminate the blocking GPU synchronizations that cause idle time.
The second is that SAM spends a lot of GPU time in matrix multiplications (dark green in the image above), which is common in Transformers. If we can reduce the GPU time SAM spends on matrix multiplications, we can speed it up significantly.
Next, the article establishes a baseline using SAM's throughput (img/s) and memory overhead (GiB), and then begins the optimization process.
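To look for this kind of bottleneck in your own model, the PyTorch Profiler can be wrapped around a single forward pass. The sketch below is a minimal, hypothetical helper (`profile_model` is our name, not the article's) that records CPU activity only, so it runs without a GPU. On a CUDA machine you would add `ProfilerActivity.CUDA` to the activity list to see kernel launches and blocking calls such as cudaStreamSynchronize.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity


def profile_model(model: nn.Module, example: torch.Tensor, row_limit: int = 10) -> str:
    """Profile one forward pass and return a per-operator summary table.

    Hypothetical helper: only CPU activity is recorded so the sketch runs anywhere.
    Add ProfilerActivity.CUDA on a GPU machine to capture kernels and syncs.
    """
    activities = [ProfilerActivity.CPU]
    with torch.no_grad(), profile(activities=activities) as prof:
        model(example)
    # Aggregate events by operator name (e.g. aten::index, aten::addmm)
    # and sort by the total CPU time spent in each operator.
    return prof.key_averages().table(sort_by="cpu_time_total", row_limit=row_limit)
```

Calling `prof.export_chrome_trace(path)` on the profiler object, instead of printing a table, writes a Chrome-format JSON trace that the Perfetto UI can open for a timeline view like the ones in this article.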
Bfloat16 half precision (plus GPU syncs and batching)
To address the matrix multiplication problem above, that is, to make matrix multiplications take less time, the article turns to bfloat16. Bfloat16 is a commonly used half-precision type that can save a lot of computation time and memory by reducing the precision of each parameter and activation.
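In PyTorch, switching to bfloat16 can be as small as a dtype cast. The sketch below, with hypothetical helper names and toy layer sizes, casts a copy of a model to bfloat16, checks that parameter memory halves (2 bytes per element instead of 4), and measures how far the outputs drift from float32. It is a quick sanity check under assumed shapes, not the article's accuracy validation.

```python
import copy

import torch
import torch.nn as nn


def to_bfloat16(model: nn.Module) -> nn.Module:
    """Return a bfloat16 copy of `model`; the original float32 model is untouched."""
    # Module.to(dtype=...) casts only floating-point parameters and buffers.
    return copy.deepcopy(model).to(dtype=torch.bfloat16)


def param_bytes(model: nn.Module) -> int:
    """Total number of bytes occupied by the model's parameters."""
    return sum(p.numel() * p.element_size() for p in model.parameters())


@torch.no_grad()
def max_abs_error(model_fp32: nn.Module, model_bf16: nn.Module, x: torch.Tensor) -> float:
    """Largest absolute difference between float32 and bfloat16 outputs on `x`."""
    ref = model_fp32(x)
    out = model_bf16(x.to(torch.bfloat16)).float()
    return (ref - out).abs().max().item()
```

Because reduced precision changes the numerics, comparing outputs (and, more importantly, end-to-end accuracy) against the float32 model is worth doing before adopting the cast.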
The article also replaces the padding dtype with bfloat16.
Removing GPU synchronization: the article found two places that can be optimized.
Specifically (refer to the picture above; the variable names are those used in the code), the study found that SAM's image encoder has variables q_coords and k_coords that act as coordinate scalers. These variables are allocated and processed on the CPU. However, once they are used to index into rel_pos_resized, the indexing operations automatically move them to the GPU, and this copy causes a GPU synchronization. The study found that this part can be fixed by rewriting it with torch.where, as shown above. After applying these changes, the article observed significant gaps between individual kernel calls, especially at small batch sizes (here, 1). To understand this better, the article profiles SAM inference with a batch size of 8:
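The article's exact torch.where rewrite is not reproduced here. As a hedged illustration of the same idea, the sketch below mirrors the coordinate arithmetic described above (the function names and the exact formula are assumptions modeled on SAM's relative-position helper) and shows one alternative fix: allocating the coordinates directly on the device of the tensor being indexed, so no host-to-device copy is needed. The last two functions show the general torch.where pattern: boolean-mask indexing produces a data-dependent shape and, on CUDA, forces a synchronization, while torch.where keeps the output shape static.

```python
import torch


def rel_pos_index_host(q_size: int, k_size: int) -> torch.Tensor:
    """Relative-position indices built on the CPU (the pattern the article flags)."""
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    rel = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel.long()


def gather_rel_pos(rel_pos: torch.Tensor, q_size: int, k_size: int) -> torch.Tensor:
    """Same indices, created on rel_pos.device so indexing needs no host copy."""
    device = rel_pos.device
    q_coords = torch.arange(q_size, device=device)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size, device=device)[None, :] * max(q_size / k_size, 1.0)
    rel = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos[rel.long()]


def masked_sum_sync(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # x[mask] has a data-dependent size, so on CUDA the host must wait for
    # the GPU to find out how many elements were selected.
    return x[mask].sum()


def masked_sum_where(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # torch.where keeps a static output shape, so no synchronization is needed.
    return torch.where(mask, x, torch.zeros_like(x)).sum()
```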
Looking at the time spent per kernel, the article observes that most of SAM's GPU time is spent in elementwise kernels and the softmax operation. The relative cost of matrix multiplication is now much smaller.
Combining GPU-synchronization removal with bfloat16 improves SAM's performance by 3 times.
Torch.compile (plus graph breaks and CUDA graphs)
While examining SAM in depth, the article found many small operations. The authors believe that a compiler which fuses operations brings large benefits, and torch.compile optimizes by:
- Fusing sequences of operations such as nn.LayerNorm or nn.GELU into a single GPU kernel;
- Fusing operations that immediately follow a matrix multiplication kernel, reducing the number of GPU kernel calls.
With these optimizations, the number of GPU global-memory round trips drops, resulting in faster inference. We can now try torch.compile on SAM's image encoder. To maximize performance, the article uses some advanced compilation techniques.
The results show that torch.compile works very well.
It can be observed that softmax takes up a large part of the time, followed by various GEMM variants. The following measurements are for batch sizes of 8 and above.
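A minimal sketch of applying torch.compile to a Transformer-style block follows. `ToyBlock` and `compile_block` are hypothetical stand-ins for SAM's image encoder, chosen because they contain the LayerNorm, matmul and GELU sequence that the fusions above target. The article's exact compile settings are not listed here; `mode="max-autotune"` is one real torch.compile option you could pass, but treat that choice as an assumption.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical MLP block: LayerNorm -> Linear -> GELU -> Linear, with a residual."""

    def __init__(self, dim: int = 64, hidden: int = 256) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.fc2(self.act(self.fc1(self.norm(x))))


def compile_block(model: nn.Module, **compile_kwargs) -> nn.Module:
    # torch.compile returns an optimized wrapper; compilation happens lazily on
    # the first call. With the default Inductor backend, elementwise ops such as
    # LayerNorm and GELU, and epilogues after matmuls, can be fused into fewer kernels.
    return torch.compile(model, **compile_kwargs)
```

`compile_block(ToyBlock().eval())` uses the default Inductor backend, which performs the fusions; the first call is slow because it triggers compilation, and later calls with the same shapes reuse the compiled code. Passing `backend="eager"` traces the model without generating new kernels, which is handy for checking that a model compiles without errors.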
SDPA: scaled_dot_product_attention
Next, the article experiments with SDPA (scaled_dot_product_attention), focusing on the attention mechanism. In general, naive attention mechanisms scale quadratically with sequence length in both time and memory. PyTorch's SDPA operation is built on the memory-efficient attention principles of Flash Attention, FlashAttentionV2, and xFormers, and can significantly speed up attention on GPUs. Combined with torch.compile, this operation allows a common pattern in variants of MultiheadAttention to be expressed and fused. After a small change, the model can now use scaled_dot_product_attention.
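The "small change" amounts to replacing a hand-written softmax(Q K^T / sqrt(d)) V with a call to `torch.nn.functional.scaled_dot_product_attention`. The sketch below (hypothetical function names, with assumed inputs of shape batch x heads x sequence x head_dim) shows both versions side by side; the naive one materializes the full score matrix, which is where the quadratic memory cost comes from.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference attention that materializes the full (L x S) score matrix."""
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale  # (..., L, S): quadratic in length
    return torch.softmax(scores, dim=-1) @ v


def sdpa_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Same computation through the fused operator; PyTorch selects the backend.

    The default scale of scaled_dot_product_attention is 1 / sqrt(head_dim),
    matching naive_attention above.
    """
    return F.scaled_dot_product_attention(q, k, v)
```

On a GPU, PyTorch dispatches to a Flash or memory-efficient kernel when the inputs qualify; on a CPU the operator still computes the same result up to floating-point tolerance, so the comparison runs anywhere.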
You can now see the memory-efficient attention kernel taking up significant computation time on the GPU:
Using PyTorch's native scaled_dot_product_attention allows significantly larger batch sizes. The graph below shows the results for batch sizes of 32 and above.
After that, the research also experimented with Triton, NestedTensor, batched predict_torch, int8 quantization, and semi-structured (2:4) sparsity. For example, the article uses a custom positional Triton kernel and reports measurements at a batch size of 32.
Results with Nested Tensor, for batch sizes of 32 and above.
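The difference between padding with a mask and using a NestedTensor can be seen on a toy batch. In the sketch below (helper names are hypothetical), two sequences of different lengths are packed once with `torch.nested.nested_tensor` and once by manual padding plus a boolean mask; converting the nested tensor back with `torch.nested.to_padded_tensor` recovers the same padded layout. Nested tensors are still a prototype API, so expect a warning and version-dependent operator coverage.

```python
from typing import List, Tuple

import torch


def pack_nested(seqs: List[torch.Tensor]) -> torch.Tensor:
    """Pack tensors that differ only in their first dimension into one NestedTensor."""
    return torch.nested.nested_tensor(seqs)


def nested_to_padded(nt: torch.Tensor, pad: float = 0.0) -> torch.Tensor:
    """Convert a NestedTensor into a regular padded tensor."""
    return torch.nested.to_padded_tensor(nt, pad)


def pad_with_mask(seqs: List[torch.Tensor], pad: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """The {tensor, mask} alternative: pad to the longest sequence and record validity."""
    max_len = max(s.size(0) for s in seqs)
    batch = seqs[0].new_full((len(seqs), max_len, *seqs[0].shape[1:]), pad)
    mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
    for i, s in enumerate(seqs):
        batch[i, : s.size(0)] = s
        mask[i, : s.size(0)] = True  # True marks real (non-padding) positions
    return batch, mask
```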
Measurements for batch sizes of 32 and above after adding quantization.
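The quantization scheme is not spelled out in this summary. As an assumption, the sketch below uses one common variant: symmetric int8 quantization with one scale per row, applied to both the activations (on the fly) and the weights, followed by an integer matrix multiply that is rescaled back to floating point. Function names are hypothetical. It runs on a CPU and emulates the int32 accumulator with float64, whereas fast GPU kernels multiply the int8 operands directly.

```python
from typing import Tuple

import torch


def quantize_rows_int8(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Symmetric int8 quantization of a 2-D tensor with one scale per row."""
    max_abs = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scale = max_abs / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale


def dynamic_int8_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Approximate x @ w.T using int8 activations and int8 weights."""
    xq, x_scale = quantize_rows_int8(x)  # per-row (per-token) scales, shape (M, 1)
    wq, w_scale = quantize_rows_int8(w)  # per-output-channel scales, shape (N, 1)
    # Integer products summed exactly; float64 stands in for an int32 accumulator.
    acc = xq.to(torch.float64) @ wq.to(torch.float64).t()
    # Undo both scales to return to the original floating-point range.
    return (acc * x_scale * w_scale.t()).to(x.dtype)
```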
The final topic is semi-structured sparsity. The study shows that matrix multiplication is still the bottleneck to tackle. The solution is to approximate matrix multiplication through sparsification. By sparsifying matrices (that is, zeroing out values), fewer bits can be used to store weight and activation tensors. The process of choosing which weights in a tensor to set to zero is called pruning. Pruning away smaller weights can potentially reduce model size without significant loss of accuracy. Pruning methods range from completely unstructured to highly structured. Although unstructured pruning theoretically has minimal impact on accuracy, GPUs, while very efficient at large dense matrix multiplications, can suffer significant performance degradation in sparse cases. A pruning method recently supported in PyTorch, called semi-structured (or 2:4) sparsity, aims to strike a balance. This sparse storage reduces the original tensor by 50% while still producing a dense tensor output. See the illustration below.
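The 2:4 pattern itself is easy to express: in every group of four consecutive weights, keep the two with the largest magnitude and zero the other two. The sketch below builds that mask with hypothetical helpers `mask_2_to_4` and `prune_2_to_4`; it is not the article's `apply_sparse`, and it only produces a dense tensor with the 2:4 zero pattern. In recent PyTorch releases, such a tensor can then be converted to the compressed layout with `torch.sparse.to_sparse_semi_structured` on a supported NVIDIA GPU, which is what unlocks the fast sparse kernels.

```python
import torch


def mask_2_to_4(w: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the 2 largest-magnitude values in every group of 4
    consecutive elements along the last dimension."""
    assert w.shape[-1] % 4 == 0, "last dimension must be a multiple of 4"
    groups = w.abs().reshape(-1, 4)
    keep = groups.topk(2, dim=1).indices  # positions of the two largest |w| per group
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(1, keep, torch.ones_like(keep, dtype=torch.bool))
    return mask.reshape(w.shape)


def prune_2_to_4(w: torch.Tensor) -> torch.Tensor:
    """Zero the two smallest-magnitude weights in each group of four."""
    return torch.where(mask_2_to_4(w), w, torch.zeros_like(w))
```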
To use this sparse storage format and its associated fast kernels, the next step is to prune the weights. For 2:4 sparsity, the article prunes the two smallest weights in each group of four. Changing the weights from PyTorch's default ("strided") layout to this new semi-structured sparse layout is easy: implementing apply_sparse(model) takes only 32 lines of Python code.
With 2:4 sparsity, the article observes peak performance for vit_b and SAM at a batch size of 32:
Finally, to summarize in one sentence: this article presents the fastest Segment Anything implementation in PyTorch so far; with the help of a series of newly released features, it rewrites the original SAM in pure PyTorch without losing accuracy. Interested readers can check out the original blog for more information.
Reference link: https://pytorch.org/blog/accelerating-generative-ai/