June 21, 2024
Torch’s design decisions have been pretty poor for a while, so I’ve been experimenting with JAX. It’s great, but only if you can actually take advantage of the speedups from jit. Below are some best practices I’ve learned from playing around with it for a while.
What Exactly is jax.jit?
JIT stands for Just-In-Time compilation. Normally, Python runs line-by-line, interpreting everything at runtime, which is convenient for flexibility but can be slow for performance-critical code. JAX changes the game by letting you compile your Python functions with jax.jit. The function is compiled and optimized the first time it is called, which reduces overhead on every call after that.
In essence, jax.jit traces your Python function and hands the result to XLA (Accelerated Linear Algebra), which emits optimized compiled code. That can make your functions run much faster, sometimes by orders of magnitude, especially for large-scale tensor operations.
How to Use jax.jit
Using jax.jit is as simple as wrapping your function with it. Here’s a basic example (try this out yourself):
Without JIT
With JIT
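A minimal sketch of the two variants side by side. The function `scaled_tanh` and the input size are hypothetical, picked only because a chain of elementwise operations is something XLA can fuse.

```python
import jax
import jax.numpy as jnp

def scaled_tanh(x):
    # Hypothetical example function: a few elementwise ops XLA can fuse.
    return 1.5 * jnp.tanh(x) + 0.1 * x ** 2

x = jnp.linspace(-3.0, 3.0, 1_000_000)

# Without JIT: each operation is dispatched one at a time.
y_eager = scaled_tanh(x)

# With JIT: the function is traced once and compiled as a whole by XLA.
scaled_tanh_jit = jax.jit(scaled_tanh)
y_jit = scaled_tanh_jit(x)

# Both versions should agree up to floating-point tolerance.
print(jnp.allclose(y_eager, y_jit, atol=1e-5))
```

If you time the two versions, call `.block_until_ready()` on the result, since JAX dispatches work asynchronously and the call can return before the computation finishes.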
This small change can result in dramatic speedups, especially when running the function repeatedly. But why?
Under the Hood
When you use jax.jit
, JAX compiles your function into a single execution graph, optimizing it with XLA. XLA looks for opportunities to fuse operations together, minimize memory access, and streamline the underlying computations. This results in reduced overhead for running your function, turning Python’s interpreted sluggishness into C++-like efficiency.
It also memoizes the compiled function, meaning that if you run the function multiple times with the same shapes and data types, you’ll skip recompilation and just execute the precompiled code.
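One way to watch the cache at work is to count how often the Python body actually runs. This is a diagnostic sketch, not something to leave in real code: the `trace_count` global and the `double` function are hypothetical, and the counter only works because Python side effects execute during tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces the function

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing only
    return x * 2

double(jnp.ones(3))                    # first call: trace and compile
double(jnp.zeros(3))                   # same shape and dtype: cache hit
count_after_same_shape = trace_count

double(jnp.ones(4))                    # new shape: traced again
double(jnp.ones(3, dtype=jnp.int32))   # new dtype: traced again
count_after_new_inputs = trace_count

print(count_after_same_shape, count_after_new_inputs)  # 1 3
```

The values inside the array never matter for the cache, only the shape and dtype do.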
When Should You Use jax.jit?
You might be tempted to slap jax.jit on everything, but you will end up wasting more time debugging than if you had just used torch. It pays off most in these cases:
- Hot Functions: Functions that are called repeatedly with the same or similar shapes will benefit the most. The compilation time for jax.jit can sometimes be non-trivial, so you want to make sure the payoff justifies the upfront cost.
- Numerical Computation: If your function is doing heavy number crunching (matrix multiplications, tensor manipulations, etc.), jax.jit will likely provide a huge boost. Pure control flow functions, or those with lots of I/O, won’t benefit as much.
- Stable Shapes: jax.jit works best when the shapes of your arrays remain consistent between function calls. If the shape varies a lot, you might end up recompiling the function frequently, which negates the performance gains.
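One common way to keep shapes stable when the real data varies in length is to pad to a fixed size and carry a mask. The sketch below assumes a hypothetical bucket size `MAX_LEN` and a hypothetical `masked_mean` function; the padding is done with NumPy on the host so that the jitted function always sees the same shape.

```python
import numpy as np
import jax
import jax.numpy as jnp

MAX_LEN = 8  # hypothetical fixed bucket size chosen for this sketch

@jax.jit
def masked_mean(padded, mask):
    # Average only the real entries; padded slots have mask == 0.
    return jnp.sum(padded * mask) / jnp.sum(mask)

def mean_variable_length(values):
    # Pad on the host so the jitted function always gets shape (MAX_LEN,).
    n = len(values)
    padded = np.zeros(MAX_LEN, dtype=np.float32)
    padded[:n] = values
    mask = (np.arange(MAX_LEN) < n).astype(np.float32)
    return masked_mean(padded, mask)

short_mean = mean_variable_length([1.0, 2.0, 3.0])        # length 3
long_mean = mean_variable_length([4.0, 6.0, 8.0, 10.0])   # length 4, same compiled code
print(short_mean, long_mean)
```

The trade-off is some wasted compute on the padded slots in exchange for a single compilation.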
Example: Optimizing a Neural Network Layer
Let’s see how jax.jit can turbocharge something more interesting: a fully connected layer in a neural network.
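A minimal sketch of such a layer, assuming a ReLU activation and made-up sizes (batch of 128, 512 inputs, 256 outputs). The name `dense_layer` and the parameter layout are illustrative choices, not taken from any particular library.

```python
import jax
import jax.numpy as jnp

def dense_layer(params, x):
    # Fully connected layer: affine transform followed by ReLU.
    w, b = params
    return jax.nn.relu(x @ w + b)

dense_layer_jit = jax.jit(dense_layer)

# Hypothetical sizes chosen for illustration: batch 128, 512 -> 256 features.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (512, 256)) * 0.01
b = jnp.zeros(256)
x = jax.random.normal(key_x, (128, 512))

out = dense_layer_jit((w, b), x)
print(out.shape)  # (128, 256)
```

The matmul, bias add, and activation are compiled together, so the jitted version avoids dispatching three separate operations from Python.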
By simply adding jax.jit, you’ll often see anywhere from a 5x to 10x speedup depending on the size of your data and the complexity of the function.
JIT + Grad = Magic
You can also use jax.jit in combination with jax.grad for computing gradients efficiently. Here’s how you would use jax.jit to accelerate a loss function during training:
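A minimal sketch, assuming a linear model with a mean-squared-error loss and plain SGD. The names `mse_loss`, `sgd_step`, and `LEARNING_RATE`, and the synthetic data, are all hypothetical and only there to make the example runnable.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value for this toy problem

def mse_loss(params, x, y):
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Forward and backward pass are compiled together into one XLA program.
    grads = jax.grad(mse_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Synthetic regression data generated from a known linear function.
x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = (jnp.zeros(3), jnp.zeros(()))
initial_loss = mse_loss(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
final_loss = mse_loss(params, x, y)
print(initial_loss, final_loss)
```

Jitting the whole update step, rather than only the loss, keeps the gradient computation and the parameter update inside one compiled call.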
Combining jax.jit with jax.grad not only accelerates the forward pass (calculating the predictions), but also speeds up the backward pass (computing the gradients), which is very useful for training neural networks at scale.
Mistakes with jax.jit
While jax.jit is extremely powerful, it’s not without its quirks. Here are a few common pitfalls to avoid:
- Side Effects: jax.jit expects pure functions, meaning that any side effects (e.g., printing, modifying global variables) won’t work as expected. They run only while the function is being traced, typically on the first call, and are skipped on later calls that reuse the compiled code.
- Dynamic Shapes: As mentioned earlier, jax.jit is most effective when your input shapes are stable. If the shape of your arrays changes often, JAX will need to recompile the function, which can offset the performance gains. I spent a lot of hours trying to figure out why my jitted code wasn’t working. Don’t do the same thing.
- Python Control Flow: Python’s native control flow (e.g., if, for, while) inside jax.jit-compiled functions won’t always behave as expected. An if or while that branches on a traced array value raises an error, and a for loop is unrolled at trace time. JAX provides its own control flow primitives like jax.lax.cond and jax.lax.scan to handle loops and conditionals efficiently. If you have complex control logic, make sure to use JAX’s alternatives.
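A small sketch of what happens to a print inside a jitted function: it fires while the function is traced and not on later cached calls. The function `add_one` is hypothetical; `jax.debug.print` is shown as one option for printing runtime values from compiled code.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_one(x):
    # Runs only while JAX traces the function, so x shows up as a tracer.
    print("tracing with", x)
    # jax.debug.print is staged into the compiled code and runs on every call.
    jax.debug.print("runtime value: {v}", v=x)
    return x + 1

first = add_one(jnp.arange(3))   # prints both messages
second = add_one(jnp.arange(3))  # prints only the jax.debug.print message
```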
Example: Control Flow in JIT
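A minimal sketch of both primitives, using two hypothetical functions: `identity_or_square` branches on a value with `jax.lax.cond`, and `running_sum` loops over an array with `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def identity_or_square(x):
    # A Python `if x > 0:` would fail here because x is a tracer under jit.
    return lax.cond(x > 0, lambda v: v, lambda v: v * v, x)

@jax.jit
def running_sum(xs):
    # lax.scan carries state through the loop without unrolling it.
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (next carry, per-step output)
    total, partials = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return total, partials

positive = identity_or_square(jnp.asarray(3.0))   # takes the first branch
negative = identity_or_square(jnp.asarray(-2.0))  # takes the second branch
total, partials = running_sum(jnp.array([1.0, 2.0, 3.0, 4.0]))
print(positive, negative, total, partials)
```

Both branches of `lax.cond` have to return values with the same shape and dtype, and the carry in `lax.scan` has to keep the same shape and dtype across iterations.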
Conclusion: Embrace the Speed
This takes a while to get used to, especially having to keep your array shapes static, but if you can get the hang of it the performance gains are well worth it. Spend a weekend or two playing with JAX.