JAX is a new competitor to TensorFlow and PyTorch. JAX emphasizes simplicity without sacrificing speed and scalability. Because JAX requires less boilerplate code, programs are shorter and closer to the underlying math, which makes them easier to understand.
Long story short:
- Use import jax.numpy to access NumPy functions and import jax.scipy to access SciPy functions.
- Decorate a function with @jax.jit to compile it just in time (on the fly) and speed it up.
- Use jax.grad to compute derivatives.
- Use jax.vmap for vectorization and jax.pmap for cross-device parallelization.
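As a quick illustration of the first point, here is a minimal sketch showing that jax.numpy and jax.scipy mirror their NumPy and SciPy counterparts. The input values are arbitrary, and logsumexp is just one SciPy function picked as an example.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0])

# jax.numpy mirrors the NumPy API: same function names, JAX arrays as results.
mean = jnp.mean(x)

# jax.scipy mirrors parts of SciPy, here scipy.special.logsumexp.
lse = logsumexp(x)

print(mean, lse)
```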
Functional programming
JAX follows a functional programming philosophy. This means that your functions must be pure: no side effects are allowed. Essentially, a pure function behaves like a mathematical function (Figure 1). Inputs go in and outputs come out, but there is no communication with the outside world.
Example #1
The following code snippet is an example of an impure function.
import jax.numpy as jnp

bias = jnp.array(0)

def impure_example(x):
    total = x + bias
    return total
Note that impure_example reads bias, which lives outside the function. During compilation (see below), the value of bias may be cached, so later changes to bias are no longer reflected.
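The caching effect can be observed directly. The sketch below (values are illustrative) jit-compiles the impure function, rebinds the global bias, and shows that the cached trace keeps using the old value until a call with a new input shape forces JAX to trace the function again.

```python
import jax
import jax.numpy as jnp

bias = jnp.array(0.0)

@jax.jit
def impure_example(x):
    # Reads the global `bias`; its value is captured when the function is traced.
    return x + bias

x = jnp.array(1.0)
first = impure_example(x)    # traced here, while bias == 0.0

bias = jnp.array(100.0)      # rebinding the global does not invalidate the cache
second = impure_example(x)   # same shape and dtype: cached trace reused, still sees 0.0

third = impure_example(jnp.array([1.0, 2.0]))  # new shape: retraced, now sees 100.0
print(first, second, third)
```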
Example #2
Here's a pure example.
def pure_example(x, weights, bias):
    activation = weights @ x + bias
    return activation
Here, pure_example is self-contained: everything it uses is passed in as an argument.
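For contrast, here is a minimal sketch that jit-compiles pure_example, using made-up weights and inputs. Because bias is an argument rather than a global, a new value is always reflected in the output.

```python
import jax
import jax.numpy as jnp

@jax.jit
def pure_example(x, weights, bias):
    # All inputs arrive as arguments, so the output depends only on them.
    return weights @ x + bias

weights = jnp.array([[1.0, 2.0], [3.0, 4.0]])
x = jnp.array([1.0, 1.0])

out_a = pure_example(x, weights, jnp.array(0.0))
out_b = pure_example(x, weights, jnp.array(100.0))  # the new bias is picked up
print(out_a, out_b)
```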
Deterministic sampler
Software cannot produce truly random numbers on its own. Instead, libraries such as NumPy and TensorFlow keep track of a hidden pseudorandom number generator state to produce "random" samples.
A direct consequence of functional programming is that random functions work differently in JAX. Since global state is not allowed, a pseudorandom number generator (PRNG) key must be passed in explicitly every time a random number is sampled:
import jax
key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
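Because the key fully determines the output, sampling twice with the same key returns the same number. A minimal sketch, with 42 as an arbitrary seed:

```python
import jax

key = jax.random.PRNGKey(42)
u1 = jax.random.uniform(key)
u2 = jax.random.uniform(key)  # same key, so the same "random" number

# The shape argument draws several samples from one key in a single call.
batch = jax.random.uniform(key, shape=(3,))
print(u1, u2, batch)
```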
In addition, it is your responsibility to advance the "random state" between calls. You do this by splitting the key and consuming a fresh subkey for each sample:
key = jax.random.PRNGKey(43)
# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
# ... and so on for each further sample.
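Splitting makes a whole sequence of samples reproducible from a single seed. The helper sample_n below is hypothetical (not part of JAX): it repeats the split-then-consume pattern n times, so calling it twice with the same seed yields identical lists. jax.random.split also accepts a count to produce several subkeys at once.

```python
import jax

def sample_n(seed, n):
    # Hypothetical helper: draw n uniform samples, splitting off a fresh subkey each time.
    key = jax.random.PRNGKey(seed)
    samples = []
    for _ in range(n):
        key, subkey = jax.random.split(key)
        samples.append(float(jax.random.uniform(subkey)))
    return samples

run_a = sample_n(43, 3)
run_b = sample_n(43, 3)  # same seed, same sequence

# split can also produce several subkeys in one call.
subkeys = jax.random.split(jax.random.PRNGKey(43), 3)
print(run_a, run_b, len(subkeys))
```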
jit
You can speed up your code by compiling JAX operations on the fly. For example, to compile the Scaled Exponential Linear Unit (SELU) function, write it with the NumPy functions from jax.numpy and add the jax.jit decorator, as follows:
import jax.numpy as jnp
from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
    return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
JAX traces the operations in your function and converts them into an intermediate representation called a jaxpr. JAX then lowers this representation to the Accelerated Linear Algebra (XLA) compiler, which generates highly optimized code for your accelerator.
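To look at the jaxpr and check that compilation does not change the result, here is a minimal sketch with an arbitrary test input. It renames α and λ to alpha and lmbda only to keep the identifiers ASCII. jax.make_jaxpr prints the traced program, and block_until_ready waits for JAX's asynchronous computation to finish, which matters if you time the call.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jnp.arange(-3.0, 4.0)

# The traced program (jaxpr) that JAX hands off for compilation.
print(jax.make_jaxpr(selu)(x))

y_eager = selu(x)
# The first call compiles; later calls with the same shape reuse the compiled code.
y_jit = selu_jit(x).block_until_ready()
```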
Grad
One of the most powerful features of JAX is how easily you can compute gradients. Given a function, jax.grad returns a new function that computes its derivative.
from jax import grad

def f(x):
    return x + 0.5 * x**2

df_dx = grad(f)          # f'(x) = 1 + x
d2f_dx2 = grad(grad(f))  # f''(x) = 1
As the example shows, you are not limited to the first derivative: you can get the nth derivative simply by nesting grad n times.
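Evaluating these functions at a sample point confirms the result: for f(x) = x + 0.5x², f'(x) = 1 + x and f''(x) = 1. The helper nth_derivative is hypothetical (not a JAX function) and simply applies jax.grad n times. Note that jax.grad expects a scalar-valued function and floating-point inputs, which is why the sketch passes 2.0 rather than 2.

```python
import jax

def f(x):
    return x + 0.5 * x**2

def nth_derivative(fn, n):
    # Hypothetical helper: wrap fn in jax.grad n times.
    for _ in range(n):
        fn = jax.grad(fn)
    return fn

df_dx = jax.grad(f)
d2f_dx2 = jax.grad(jax.grad(f))
d3f_dx3 = nth_derivative(f, 3)

print(df_dx(2.0))    # f'(x) = 1 + x, prints 3.0
print(d2f_dx2(2.0))  # f''(x) = 1, prints 1.0
print(d3f_dx3(2.0))  # f'''(x) = 0, prints 0.0
```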
vmap and pmap
Getting every batch dimension right in matrix multiplications requires great care. JAX's vectorizing map, vmap, removes this burden: you write a function for a single example, and vmap vectorizes it over a batch dimension. Essentially, any code that applies a function f element by element across a batch is a candidate to be replaced by vmap. Let's look at an example.
Consider the linear function:
def linear(x):
    return weights @ x
For a batch of examples [x₁, x₂, …], we can implement it naively (without vmap) like this:
def naively_batched_linear(X_batched):
    return jnp.stack([linear(x) for x in X_batched])
By contrast, vectorizing linear with vmap computes the entire batch at once:
def vmap_batched_linear(X_batched):
    return jax.vmap(linear)(X_batched)
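A self-contained sketch that puts the pieces together, with an assumed 3×2 weight matrix and a batch of four inputs: the loop and the vmap version agree, and vmap composes with jit. The last lines sketch jax.pmap, which follows the same mapping pattern but spreads the leading axis across devices. This part assumes a JAX version that still provides jax.pmap and sizes the input to jax.local_device_count(), which is typically 1 on a CPU-only machine.

```python
import jax
import jax.numpy as jnp

# Assumed example data: a 3x2 weight matrix and a batch of 4 inputs of length 2.
weights = jnp.arange(6.0).reshape(3, 2)
X_batched = jnp.arange(8.0).reshape(4, 2)

def linear(x):
    # Written for a single example x of shape (2,).
    return weights @ x

def naively_batched_linear(X_batched):
    return jnp.stack([linear(x) for x in X_batched])

def vmap_batched_linear(X_batched):
    return jax.vmap(linear)(X_batched)

naive = naively_batched_linear(X_batched)
vectorized = jax.jit(vmap_batched_linear)(X_batched)  # vmap composes with jit

# pmap maps over devices instead of a vectorized batch axis; the leading
# axis size must not exceed the number of available devices.
n_devices = jax.local_device_count()
per_device = jax.pmap(linear)(jnp.ones((n_devices, 2)))
print(naive, vectorized, per_device)
```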