
JAX: Fast as PyTorch and as simple as NumPy - Deep Learning & Data Science

Author: Refrigeration plant

JAX is a new competitor to TensorFlow and PyTorch. JAX emphasizes simplicity without sacrificing speed and scalability. Since JAX requires less boilerplate code, programs are shorter and closer to the math, and therefore easier to understand.

Long story short:

  • Use import jax.numpy to access NumPy functions and import jax.scipy to access SciPy functions.
  • Decorate functions with @jax.jit to speed them up via just-in-time compilation.
  • Use jax.grad to take derivatives.
  • Use jax.vmap for vectorization and jax.pmap for parallelization across devices (a short sketch combining these pieces follows below).
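
As a quick taste, here is a minimal sketch combining these pieces on a toy model; the function predict and the shapes used are invented purely for illustration.

import jax
import jax.numpy as jnp

@jax.jit                     # compile just in time on the first call
def predict(w, x):
    return jnp.tanh(w @ x)   # toy model: tanh of a matrix-vector product

grad_fn = jax.grad(lambda w, x: predict(w, x).sum())    # gradient with respect to w
batched_predict = jax.vmap(predict, in_axes=(None, 0))  # vectorize over a batch of x

w = jnp.ones((3, 4))
x_batch = jnp.ones((8, 4))
y_batch = batched_predict(w, x_batch)  # shape (8, 3)
g = grad_fn(w, x_batch[0])             # same shape as w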

Functional programming

[Figure 1: a pure function maps inputs to outputs with no communication with the outside world]

JAX follows a functional programming philosophy. This means that your functions must be self-contained, or pure: no side effects are allowed. Essentially, a pure function looks like a mathematical function (Figure 1). Input goes in, output comes out, but there is no communication with the outside world.

Example #1

The following code snippet is an example of a function that is not pure.

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias   # depends on the global bias defined outside the function
   return total
           

Note that bias is defined outside of impure_example. During compilation (see below), bias may be cached, which means that later changes to bias are no longer reflected.
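
To see the pitfall concretely, here is a minimal sketch of what typically happens once impure_example is jit-compiled (the exact caching behavior is an implementation detail):

import jax
import jax.numpy as jnp

bias = jnp.array(0)

@jax.jit
def impure_example(x):
   total = x + bias
   return total

print(impure_example(jnp.array(1)))  # 1, with bias == 0 baked in when the function is traced
bias = jnp.array(10)
print(impure_example(jnp.array(1)))  # still 1: the cached, compiled version ignores the new bias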

Example #2

Here's a pure example.

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation
           

Here, pure_example is self-contained: all parameters are passed in as arguments.
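
For example, calling it with explicit inputs (the shapes here are chosen purely for illustration) gives the same output for the same arguments, with no hidden state involved:

x = jnp.ones(3)
weights = jnp.ones((2, 3))
bias = jnp.ones(2)

activation = pure_example(x, weights, bias)  # shape (2,)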

Deterministic sampler


In computers, there is no such thing as true randomness. Instead, libraries such as NumPy and TensorFlow keep track of a pseudorandom number generator state in order to produce "random" samples.

A direct consequence of functional programming is that random functions work differently. Since global state is no longer allowed, you need to explicitly pass in a pseudorandom number generator (PRNG) key each time you sample a random number:

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
           

In addition, it is your responsibility to advance the "random state" for any subsequent calls.

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# ... and so on for any further samples.
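
If you need several independent samples at once, the same split function can produce multiple subkeys in a single call; a small sketch:

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=4)        # keep one key, split off three subkeys
samples = [jax.random.uniform(k) for k in subkeys]  # three independent uniform samples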
           

jit

You can speed up your code by compiling JAX instructions just in time. For example, to compile the Scaled Exponential Linear Unit (SELU) function, use the NumPy functions from jax.numpy and add the jax.jit decorator to the function, as follows:

from jax import jit
import jax.numpy as jnp

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
           

JAX traces your instructions and converts them into a jaxpr, its intermediate representation. This enables the Accelerated Linear Algebra (XLA) compiler to generate very efficient code, optimized for your accelerator.
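
You can inspect this intermediate representation yourself with jax.make_jaxpr; for example, tracing the selu function above on a small input:

print(jax.make_jaxpr(selu)(jnp.arange(3.0)))  # prints the jaxpr recorded for this input shape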

grad

One of the most powerful features of JAX is that you can easily take gradients. With jax.grad, you define a new function that computes the derivative of the original one.

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)          # first derivative
d2f_dx2 = grad(grad(f))  # second derivative
           

As you can see in the example, you are not limited to the first derivative. You can get the n-th derivative simply by chaining grad n times.
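
For the f above, the first derivative is 1 + x and the second derivative is the constant 1, and the grad-derived functions reproduce exactly that:

print(df_dx(2.0))    # 3.0, since df/dx = 1 + x
print(d2f_dx2(2.0))  # 1.0, since d²f/dx² = 1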

vmap and pmap

Getting all the batch dimensions right in a matrix multiplication requires great care. JAX's vectorizing map, vmap, relieves this burden by vectorizing a function for you. Essentially, every block of code that applies a function f element by element over a batch is a candidate to be replaced by vmap. Let's look at an example.

Consider a linear function:

def linear(x):
 return weights @ x
           

For a batch of examples [x₁, x₂, …], we could naively (without vmap) implement batching like this:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])
           

By vectorizing linear with vmap instead, we can compute the entire batch at once:

from jax import vmap

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)
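
Both versions compute the same result. The same pattern extends to pmap, which distributes work across devices instead of vectorizing it; the sketch below invents weights and the shapes, and assumes the batch divides evenly across the available devices:

import jax
import jax.numpy as jnp
from jax import vmap, pmap

weights = jnp.ones((4, 3))
X_batched = jnp.ones((8, 3))

assert jnp.allclose(naively_batched_linear(X_batched), vmap_batched_linear(X_batched))

n_devices = jax.device_count()
X_sharded = X_batched.reshape(n_devices, -1, 3)  # one shard per device
Y_sharded = pmap(vmap(linear))(X_sharded)        # parallel across devices, vectorized within each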
           
