Google is ecstatic: JAX performance surpasses PyTorch and TensorFlow! It may become the fastest choice for GPU training and inference

王林
2024-04-01 19:46:11

JAX, the framework promoted by Google, has outperformed PyTorch and TensorFlow in recent benchmark tests, ranking first on 7 metrics.

And the tests were not even run on TPUs, where JAX performs best.

Among developers today, PyTorch is still more popular than TensorFlow.

But in the future, more large models may be trained and run on JAX.

Model

Recently, the Keras team benchmarked Keras 3 on its three backends (TensorFlow, JAX, PyTorch) against native PyTorch implementations and against Keras 2 with TensorFlow.

First, they selected a set of mainstream computer vision and natural language processing models covering both generative and non-generative AI tasks.

The Keras versions of the models were built from the existing implementations in KerasCV and KerasNLP. For the native PyTorch versions, they chose the most popular options available online:

- BERT, Gemma, Mistral from HuggingFace Transformers

- StableDiffusion from HuggingFace Diffusers

- SegmentAnything from Meta

They call this set of models "Native PyTorch" to distinguish it from the Keras 3 version that uses the PyTorch backend.

They used synthetic data for all benchmarks, bfloat16 precision for all LLM training and inference, and LoRA fine-tuning for all LLM training.

Following the PyTorch team's recommendation, they used torch.compile(model, mode="reduce-overhead") in the native PyTorch implementations (except for Gemma and Mistral training, where it was incompatible).

To measure out-of-the-box performance, they used high-level APIs (such as HuggingFace's Trainer(), standard PyTorch training loops, and Keras model.fit()) with as little configuration as possible.

Hardware configuration

All benchmarks were run on Google Cloud Compute Engine, on a machine with one NVIDIA A100 GPU with 40GB of GPU memory, 12 virtual CPUs, and 85GB of host memory.

Benchmark Results

Table 2 shows the benchmark results in milliseconds per step (ms/step). Each step involves training or prediction on a single batch of data.

Each result is the average over 100 steps, excluding the first step, which includes model creation and compilation and therefore takes extra time.
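The averaging rule described above can be sketched in a few lines of Python. This is a minimal illustration, not the Keras team's harness: `average_step_ms` and `dummy_step` are hypothetical names, and the sketch assumes one untimed warm-up call followed by 100 timed steps, which is one possible reading of the description.

```python
import time
from statistics import mean


def average_step_ms(step_fn, num_steps=100):
    """Call step_fn num_steps + 1 times and average all but the first call."""
    durations_ms = []
    for _ in range(num_steps + 1):
        start = time.perf_counter()
        step_fn()
        durations_ms.append((time.perf_counter() - start) * 1000.0)
    # Drop the first measurement: it absorbs one-time setup such as compilation.
    return mean(durations_ms[1:])


# Stand-in for a real training or inference step (hypothetical workload).
calls = []


def dummy_step():
    calls.append(1)
    sum(i * i for i in range(1000))


avg_ms = average_step_ms(dummy_step, num_steps=100)
print(f"average step time: {avg_ms:.4f} ms")
```

On a GPU, a real step function would also need to wait for the device to finish before the timer stops, since frameworks generally dispatch work asynchronously.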

To ensure a fair comparison, the same batch size is used for the same model and task (whether training or inference).

Across different models and tasks, however, the batch size was adjusted to each model's scale and architecture: large enough to keep the GPU well utilized, but small enough to avoid running out of memory.

A batch size that is too small can also make PyTorch appear slower because it increases Python overhead.
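A rough way to see why small batches exaggerate framework overhead is to treat each step as a fixed per-step cost plus work that grows with the batch. The sketch below assumes that simple linear model, which real GPUs only approximate, and every number in it is hypothetical rather than measured.

```python
def overhead_share(batch_size, overhead_ms, compute_ms_per_sample):
    """Fraction of one step's time spent on fixed per-step overhead."""
    step_ms = overhead_ms + batch_size * compute_ms_per_sample
    return overhead_ms / step_ms


# Hypothetical numbers: 2 ms of fixed overhead per step, 0.5 ms of work per sample.
for batch_size in (1, 4, 32):
    share = overhead_share(batch_size, overhead_ms=2.0, compute_ms_per_sample=0.5)
    print(f"batch size {batch_size:>2}: {share:.0%} of step time is overhead")
```

Under these assumed numbers the fixed cost dominates at a batch size of 1 and becomes a small fraction at 32, which is the effect the paragraph above describes.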

For the large language models (Gemma and Mistral), the same batch size was also used when testing because they are the same type of model with a similar number of parameters (7B).

Considering users’ needs for single-batch text generation, a benchmark test was also conducted on text generation with a batch size of 1.

Key findings

Finding 1

There is no single "best" backend.

Each of the three Keras backends has its own strengths. Importantly, no single backend wins on performance every time.

Which backend is fastest often depends on the architecture of the model.

This point highlights the importance of choosing different frameworks to pursue optimal performance. Keras 3 makes it easy to switch backends to find the best fit for your model.
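In Keras 3 the backend is chosen through the `KERAS_BACKEND` environment variable, which has to be set before Keras is first imported. The sketch below wraps that in a small helper; `select_keras_backend` is a hypothetical name introduced here for illustration, and the sketch deliberately stops short of importing Keras so that it runs without any framework installed.

```python
import os

SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_keras_backend(name):
    """Set KERAS_BACKEND; must be called before the first `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_keras_backend("jax")
# From here on, `import keras` would pick up the JAX backend. The model code
# itself stays the same; only the environment variable changes between runs.
print(os.environ["KERAS_BACKEND"])
```

Because the backend cannot be changed once Keras has been imported, comparing backends usually means running the same script once per backend in separate processes.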

Finding 2

Keras 3 is generally faster than the reference PyTorch implementations.

Compared to native PyTorch, Keras 3 has a significant improvement in throughput (steps/ms).

In particular, in 5 of the 10 test tasks, the speed increase exceeded 50%. Among them, the highest reached 290%.

Here, a speedup of 100% means that Keras 3 is twice as fast as PyTorch, and 0% means that the two perform the same.
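The percentage convention can be made concrete with a short calculation. This is a sketch under the assumption that each implementation is characterized by its time per step in milliseconds; `speedup_percent` is a hypothetical helper and the timings are made up for illustration, not taken from the benchmark.

```python
def speedup_percent(baseline_ms_per_step, candidate_ms_per_step):
    """Throughput gain of the candidate over the baseline, in percent.

    Throughput is the inverse of time per step, so the throughput ratio
    equals baseline time divided by candidate time.
    """
    return (baseline_ms_per_step / candidate_ms_per_step - 1.0) * 100.0


# Hypothetical timings, not taken from the benchmark tables.
print(speedup_percent(200.0, 100.0))  # candidate needs half the time per step
print(speedup_percent(150.0, 150.0))  # identical timings
```

With this definition, halving the time per step corresponds to a 100% speedup, and a 290% speedup means the baseline takes 3.9 times as long per step.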

Finding 3

Keras 3 delivers best-in-class performance “out of the box”.

In other words, none of the Keras models in the test were optimized in any way. With native PyTorch implementations, by contrast, users usually have to do more performance tuning themselves.

Beyond the data shared above, the team also noticed during testing that upgrading HuggingFace Diffusers from version 0.25.0 to 0.3.0 improved StableDiffusion inference performance by more than 100%.

Similarly, upgrading HuggingFace Transformers from version 4.38.1 to 4.38.2 significantly improved Gemma's performance.

These performance improvements highlight HuggingFace’s focus and efforts in performance optimization.

For models that have received less manual optimization, such as SegmentAnything, the implementation provided by the research authors was used. In this case, the performance gap relative to Keras is larger than for most other models.

This shows that Keras can provide excellent out-of-the-box performance, and users can enjoy fast model running speeds without having to delve into all optimization techniques.

Finding 4

Keras 3 consistently outperforms Keras 2.

For example, SegmentAnything’s inference speed has increased by an astonishing 380%, StableDiffusion’s training processing speed has increased by more than 150%, and BERT’s training processing speed has also increased by more than 100%.

This is mainly because Keras 2 uses more TensorFlow fused ops directly in some cases, which can be suboptimal for XLA compilation.

It’s worth noting that even just upgrading to Keras 3 and continuing to use the TensorFlow backend can result in significant performance improvements.

Conclusion

The performance of the framework depends largely on the specific model used.

Keras 3 can help choose the fastest framework for the task, and this choice will almost always outperform Keras 2 and PyTorch implementations.
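Once per-step timings have been collected for each backend, picking the fastest one is a one-line comparison. The sketch below assumes timings in ms/step, where lower is better; `fastest_backend` is a hypothetical helper and the numbers are placeholders, not benchmark results.

```python
def fastest_backend(ms_per_step):
    """Return the backend name with the lowest measured time per step."""
    return min(ms_per_step, key=ms_per_step.get)


# Hypothetical measurements for one model and task (ms/step, lower is better).
measured = {"tensorflow": 120.0, "jax": 95.0, "torch": 140.0}
print(fastest_backend(measured))
```

Since the winner depends on the model, the comparison would need to be repeated for each model and task rather than decided once.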

More importantly, Keras 3 models provide excellent out-of-the-box performance without complex underlying optimizations.

Statement:
This article is reproduced from 51cto.com. If there is any infringement, please contact admin@php.cn to have it removed.