What on earth is JAX?

JAX is Autograd and XLA, brought together for high-performance machine learning research.

Can you fully understand this sentence? Well, if you can, you don’t need to read this article :)

Most deep learning researchers use frameworks like TensorFlow and PyTorch for their work. Thanks to the advent of these DNN frameworks, anyone (maybe) can implement a DNN model with little (maybe) effort. There are several reasons we love these frameworks.

  1. DNN frameworks provide a simple interface for defining a DNN model.
  2. DNN frameworks free us from paying attention to the backend.
  3. DNN frameworks are fast enough.

JAX is another deep learning framework. But why do we need a new one? This article explains why we should use JAX and how we can exploit this rather unfamiliar framework.
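Before going further, here is a minimal sketch of what "Autograd and XLA" means in practice. The function name `loss` and the toy data are my own illustration, not from the JAX docs:

```python
import jax
import jax.numpy as jnp

# A plain Python function on arrays; JAX can both differentiate and compile it.
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

grad_loss = jax.grad(loss)  # Autograd: automatic differentiation
fast_loss = jax.jit(loss)   # XLA: just-in-time compilation

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
print(grad_loss(w, x))  # gradient of loss with respect to w
print(fast_loss(w, x))  # same value as loss(w, x), but XLA-compiled
```

Note that `grad` and `jit` are ordinary function transformations: they take a Python function and return a new one, which is the core idea behind JAX's design.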

TensorFlow vs. PyTorch

In order to further explain JAX, I would like to explain how TensorFlow and PyTorch differ from each other. The major difference between them is their programming paradigm: TensorFlow’s programming model is declarative, while PyTorch’s is imperative.

Most deep learning models can be represented as a data-flow graph, because control flow (if or while statements) rarely appears. So TensorFlow takes the strategy that programmers first define their model architecture, and the real computation is then conducted statically on that definition. This is often called define-and-run. In contrast, PyTorch’s programming model is define-by-run, which lets programmers run computations directly on tensors. By providing an intuitive, imperative interface, define-by-run makes models easier to write and debug.
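Interestingly, JAX combines both paradigms: you write imperative, define-by-run Python code, and JAX can trace it into a static graph (a "jaxpr") for compilation. The function `f` below is a hypothetical example of mine to show the two views of the same code:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0

# Define-by-run view: the operations execute immediately, like PyTorch.
y = f(jnp.array(1.0))

# Define-and-run view: JAX first traces f into a graph representation
# (a jaxpr), which can then be compiled and executed by XLA.
print(jax.make_jaxpr(f)(jnp.array(1.0)))
```

Printing the jaxpr shows the extracted data-flow graph (`tanh` followed by `mul`), which is what TensorFlow-style frameworks work with directly.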

XLA (Accelerated Linear Algebra)

XLA is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.

XLA is a compiler specialized for deep learning: because deep learning workloads are dominated by a fixed set of linear-algebra operations, XLA can analyze and optimize the whole computation graph rather than executing each operation separately.

XLA is a deep learning compiler whose key techniques include:

  1. Operation fusion.
  2. Progressive lowering leveraging MLIR.
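Operation fusion can be seen from JAX's side with `jax.jit`. The function `fused` below is my own toy example: it chains several element-wise operations that XLA is free to fuse into a single kernel, avoiding intermediate arrays in memory.

```python
import jax
import jax.numpy as jnp

def fused(x):
    # Three element-wise operations; without fusion, each would
    # materialize an intermediate array before the next one runs.
    return jnp.exp(-x) + jnp.sin(x) * 0.5

fast = jax.jit(fused)        # XLA compiles (and may fuse) the whole chain
x = jnp.linspace(0.0, 1.0, 1024)
fast(x)                      # first call compiles; later calls reuse the kernel
```

The compiled and uncompiled versions compute the same result; the difference is that XLA sees the entire expression at once and can schedule it as one fused kernel.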

So, why JAX?

When JAX?