
TensorFlow, PyTorch, and JAX: Which deep learning framework is better for you?



Translator | Zhu Xianzhong

Reviewer | Ink color

Deep learning affects our lives in some form every day. Whether it's Siri, Alexa, real-time translation apps driven by your voice, or the computer vision that powers smart tractors, warehouse robots, and self-driving cars, every month seems to bring new advances. And almost all of these deep learning applications are written in one of three frameworks: TensorFlow, PyTorch, or JAX.

So, which deep learning framework should you use? In this article we offer a high-level comparison of TensorFlow, PyTorch, and JAX. Our goal is to give you a sense of the types of applications that play to each framework's strengths, while of course considering factors such as community support and ease of use.

Should you use TensorFlow?

"No one ever got fired for buying IBM" was a slogan in the computer world in the 1970s and 1980s. At the beginning of this century, the same was true for deep learning using TensorFlow. But as we all know, by the time we entered the 1990s, IBM had been "put on the back burner." So, is TensorFlow still competitive today, 7 years after its initial release in 2015, and into the new decade ahead?

Certainly. TensorFlow hasn't stood still. TensorFlow 1.x was about building static graphs in a decidedly non-Pythonic way, but with TensorFlow 2.x you can also build models using eager mode, which evaluates operations immediately and makes the framework feel much more like PyTorch. At the high level, TensorFlow provides Keras to ease development; at the low level, it provides the XLA (Accelerated Linear Algebra) optimizing compiler for speed. XLA works wonders for GPU performance, and it is the primary way of tapping the power of Google's TPUs (Tensor Processing Units), which deliver unparalleled performance for large-scale model training.
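Here is a minimal sketch of what those two levels look like in TensorFlow 2.x: eager execution at the bottom, the Keras API on top, and opt-in XLA compilation; the layer sizes and shapes are illustrative assumptions, not a recommended architecture.

```python
import tensorflow as tf

# Eager mode: operations are evaluated immediately, no session required.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(tf.reduce_sum(x))  # tf.Tensor(10.0, shape=(), dtype=float32)

# The high-level Keras API for defining and compiling a model.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(32,)),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# XLA compilation can be enabled per-function for extra speed.
@tf.function(jit_compile=True)
def forward(batch):
    return model(batch)
```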

Second, TensorFlow has strived over the years to be good at almost everything. For example, do you want to serve models in a well-defined, repeatable way on a mature platform? TensorFlow Serving is there for you. Do you want to retarget model deployment to the web, to low-power compute such as smartphones, or to resource-constrained devices such as IoT hardware? Both TensorFlow.js and TensorFlow Lite are very mature at this point.
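As a rough sketch of that Lite path, converting a trained Keras model for on-device deployment looks something like the following; the stand-in model and output filename are illustrative assumptions.

```python
import tensorflow as tf

# An illustrative stand-in for a trained Keras model.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(32,))])

# Convert to TensorFlow Lite, optionally optimizing for small devices.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # e.g., weight quantization
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```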

Obviously, given that Google still runs 100% of its production deployments on TensorFlow, you can be confident that TensorFlow will meet your needs at scale.

However, some recent developments are hard to ignore. In short, upgrading a project from TensorFlow 1.x to TensorFlow 2.x has been brutal in practice. Some companies, weighing the effort required to update their code to work properly on the new version, simply decided to port it to PyTorch instead. TensorFlow has also lost momentum in the research community, which began to prefer the flexibility of PyTorch a few years ago, leading to a steady decline in TensorFlow's share of research papers.

In addition, the "Keras incident" did not play any role. Keras became an integrated part of the TensorFlow distribution two years ago, but has recently been pulled back into a separate library with its own release plan. Of course, excluding Keras won't affect developers' daily lives, but having such a dramatic change in a small updated version of the framework doesn't inspire confidence among programmers to use the TensorFlow framework.

Having said all that, TensorFlow is a dependable framework. It hosts an extensive deep learning ecosystem, and users can build applications and models of all scales on it; you would be in good company doing so. But today, TensorFlow may not be the first choice.

Should you use PyTorch?

PyTorch is no longer the upstart nipping at TensorFlow's heels; it is now a major force in deep learning, perhaps still used mainly for research but increasingly for production applications. And with eager mode now the default development approach in both TensorFlow and PyTorch, the more Pythonic approach offered by PyTorch's autograd seems to have won the war against static graphs.
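To see what that Pythonic approach means in practice, here is a minimal autograd sketch; the function being differentiated is a toy example.

```python
import torch

# Gradients are recorded as ordinary Python code runs, with no
# separate graph-building step.
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()   # y = x1^2 + x2^2, computed immediately
y.backward()         # populate x.grad via reverse-mode autodiff
print(x.grad)        # tensor([4., 6.]), i.e. dy/dx = 2x
```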

Unlike TensorFlow, PyTorch's core code hasn't seen any major breaking changes since the Variable API was deprecated in version 0.4. Previously, Variables were required for autograd to work with tensors; now everything is simply a tensor. That's not to say there aren't missteps here and there. For example, if you have been training across multiple GPUs with PyTorch, you have likely run into the difference between DataParallel and the newer DistributedDataParallel. You should almost always use DistributedDataParallel, yet nothing actively stops you from using DataParallel instead.
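For reference, here is a hedged sketch of the recommended DistributedDataParallel setup; the model is a placeholder, and the script is assumed to be launched with torchrun (one process per GPU), e.g. `torchrun --nproc_per_node=4 train.py`.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets the rank and world size via environment variables.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(32, 10).cuda(local_rank)  # illustrative model
    model = DDP(model, device_ids=[local_rank])       # one process per GPU

    # ... training loop goes here; each process would see its own shard of
    # the data via torch.utils.data.distributed.DistributedSampler ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```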

Although PyTorch long lagged behind TensorFlow and JAX in XLA/TPU support, the situation has improved greatly as of 2022. PyTorch now supports TPU virtual machines as well as legacy TPU Node deployments, along with simple command-line deployment for running your code on CPU, GPU, or TPU with no code changes. If you don't want to deal with the boilerplate code PyTorch often makes you write, you can turn to higher-level add-ons such as PyTorch Lightning, which lets you focus on actual work rather than rewriting training loops. On the other hand, although work on PyTorch Mobile continues, it remains far less mature than TensorFlow Lite.
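As an illustration of the boilerplate Lightning removes, here is a minimal sketch of a LightningModule; the model, loss, and hyperparameters are illustrative assumptions.

```python
import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(32, 10)  # toy model

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.net(x), y)
        return loss  # Lightning handles backward(), optimizer.step(), etc.

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# trainer = pl.Trainer(accelerator="auto", max_epochs=3)
# trainer.fit(LitClassifier(), train_dataloaders=...)
```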

On the production side, PyTorch can now integrate with framework-agnostic platforms such as Kubeflow, and the TorchServe project handles deployment details such as scaling, metrics, and batched inference, giving you all the MLOps goodness in a small package maintained by the PyTorch developers themselves. Does PyTorch scale? No problem. Meta has been running PyTorch in production for years, so anyone who tells you that PyTorch can't handle large-scale workloads is lying to you. Nonetheless, there are cases where PyTorch may not be as friendly as JAX, particularly for very heavy training runs that require vast numbers of GPUs or TPUs.

Finally, there is the elephant in the room: PyTorch's popularity over the past few years is almost inseparable from the success of Hugging Face's Transformers library. Yes, Transformers now supports TensorFlow and JAX as well, but it began as a PyTorch project and remains tightly integrated with the framework. Given the rise of the Transformer architecture, PyTorch's flexibility for research, and the ability to pull in so many new models within days or hours of release through Hugging Face's model hub, it is easy to see why PyTorch has become so popular in these areas.
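That tight integration is most visible in Transformers' pipeline API; here is a minimal sketch (the task and the default checkpoint it downloads are illustrative).

```python
from transformers import pipeline

# Pulls a default pretrained checkpoint from the Hugging Face hub.
classifier = pipeline("sentiment-analysis")
print(classifier("PyTorch and Transformers make research prototyping fast."))
# [{'label': 'POSITIVE', 'score': 0.99...}]
```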

Should you use JAX?

If you're not sold on TensorFlow, Google may have something else for you. JAX is a deep learning framework built, maintained, and used by Google, but it isn't an official Google product. However, if you look at Google/DeepMind papers and product releases over the past year or so, you will notice that much of Google's research has moved to JAX. So while JAX isn't an "official" Google product, it is what Google researchers use to push the boundaries.

What exactly is JAX? An easy way to think about it: imagine a GPU/TPU-accelerated version of NumPy that can, as if by magic, vectorize Python functions and handle the computation of their derivatives. Finally, it has a just-in-time (JIT) component that takes your code and optimizes it for the XLA (Accelerated Linear Algebra) compiler, resulting in significant performance improvements over TensorFlow and PyTorch. Some code currently runs four to five times faster simply by being reimplemented in JAX, with no real optimization work at all.
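Here is a minimal sketch of those three transforms, jit, grad, and vmap, applied to a toy NumPy-style function; the function and shapes are illustrative.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)  # a toy scalar-valued function

grad_loss = jax.grad(loss)                    # derivative with respect to w
fast_loss = jax.jit(loss)                     # compiled via XLA
batched = jax.vmap(loss, in_axes=(None, 0))   # auto-vectorized over a batch of x

w = jnp.ones((3,))
xs = jnp.ones((8, 4, 3))  # a batch of 8 inputs, each of shape (4, 3)
print(grad_loss(w, xs[0]))  # shape (3,)
print(fast_loss(w, xs[0]))  # scalar
print(batched(w, xs))       # shape (8,), one loss per batch element
```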

Given that JAX works at the NumPy level, JAX code is written at a much lower level than TensorFlow/Keras, or even PyTorch. Happily, a small but growing ecosystem of extensions has sprung up around JAX. Do you want a neural network library? Sure: there's Flax from Google and Haiku from DeepMind (also Google). There's Optax for all your optimizer needs, PIX for image processing, and much more. Once you are working with something like Flax, building neural networks becomes relatively easy to pick up. Be aware that there are still some rough edges; for example, veterans frequently point out that JAX handles random numbers differently from many other frameworks.
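To illustrate that random-number point: JAX has no hidden global RNG state; you pass and split explicit keys, as in this small sketch.

```python
import jax

key = jax.random.PRNGKey(42)          # explicit seed, no hidden global state
key, subkey = jax.random.split(key)   # derive a fresh key for each use
noise = jax.random.normal(subkey, shape=(3,))

# Reusing the same key reproduces the same numbers, by design:
same = jax.random.normal(subkey, shape=(3,))
assert (noise == same).all()
```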

So, should you convert everything to JAX and ride this cutting edge? It depends. If you are diving into large-scale models that demand enormous resources to train, the move is worth considering. And if JAX appeals to you for deterministic training, or for projects that require thousands of TPU pods, it is certainly worth a try.

Summary

So, what's the verdict? Which deep learning framework should you use? Unfortunately, there is no single answer; it all depends on the type of problem you are working on, the scale at which you plan to deploy your models, and even the compute platforms you are targeting.

However, if you work in text and images and are doing small- to medium-scale research with an eye toward deploying those models in production, then PyTorch is probably your best bet right now. Judging by recent releases, it hits the sweet spot for that kind of application space.

If you need to squeeze every bit of performance out of low-compute devices, then TensorFlow, with its extremely robust TensorFlow Lite package, is the recommendation. And if you are looking at training models with tens or hundreds of billions of parameters or more, primarily for research purposes, then it may be time to give JAX a try.

Original link: https://www.infoworld.com/article/3670114/tensorflow-pytorch-and-jax-choosing-a-deep-learning-framework.html

Translator introduction

Zhu Xianzhong, 51CTO community editor, 51CTO expert blogger and lecturer, computer teacher at a university in Weifang, and a veteran of freelance programming.

