JAX 是 TensorFlow 和 PyTorch 的新競争對手。 JAX 強調簡單性而不犧牲速度和可擴充性。由于 JAX 需要更少的樣闆代碼,是以程式更短、更接近數學,是以更容易了解。
長話短說:
- 使用 import jax.numpy 通路 NumPy 函數,使用 import jax.scipy 通路 SciPy 函數。
- 通過使用 @jax.jit 進行裝飾,可以加快即時編譯速度。
- 使用 jax.grad 求導。
- 使用 jax.vmap 進行矢量化,并使用 jax.pmap 進行跨裝置并行化。
函數式程式設計
JAX 遵循函數式程式設計哲學。這意味着您的函數必須是獨立的或純粹的:不允許有副作用。本質上,純函數看起來像數學函數(圖 1)。有輸入進來,有東西出來,但與外界沒有溝通。
例子#1
以下代碼片段是一個非功能純的示例。
import jax.numpy as jnp
bias = jnp.array(0)
def impure_example(x):
total = x + bias
return total
注意 impure_example 之外的偏差。在編譯期間(見下文),偏差可能會被緩存,是以不再反映偏差的變化。
例子#2
這是一個pure的例子。
def pure_example(x, weights, bias):
activation = weights @ x + bias
return activation
在這裡,pure_example 是獨立的:所有參數都作為參數傳遞。
确定性采樣器
在計算機中,不存在真正的随機性。相反,NumPy 和 TensorFlow 等庫會跟蹤僞随機數狀态來生成“随機”樣本。
函數式程式設計的直接後果是随機函數的工作方式不同。由于不再允許全局狀态,是以每次采樣随機數時都需要顯式傳入僞随機數生成器 (PRNG) 密鑰
import jax
key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
此外,您有責任為任何後續調用推進“随機狀态”。
key = jax.random.PRNGKey(43)
# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
..
jit
您可以通過即時編譯 JAX 指令來加快代碼速度。例如,要編譯縮放指數線性機關 (SELU) 函數,請使用 jax.numpy 中的 NumPy 函數并将 jax.jit 裝飾器添加到該函數,如下所示:
from jax import jit
@jit
def selu(x, α=1.67, λ=1.05):
return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
JAX 會跟蹤您的指令并将其轉換為 jaxpr。這使得加速線性代數 (XLA) 編譯器能夠為您的加速器生成非常高效的優化代碼。
gard
JAX 最強大的功能之一是您可以輕松擷取 gard。使用 jax.grad,您可以定義一個新函數,即符号導數。
from jax import grad
def f(x):
return x + 0.5 * x**2
df_dx = grad(f)
d2f_dx2 = grad(grad(f))
正如您在示例中看到的,您不僅限于一階導數。您可以通過簡單地按順序連結 grad 函數 n 次來擷取 n 階導數。
vmap 和 pmap
矩陣乘法使所有批次尺寸正确需要非常細心。 JAX 的矢量化映射函數 vmap 通過對函數進行矢量化來減輕這種負擔。基本上,每個按元素應用函數 f 的代碼塊都是由 vmap 替換的候選者。讓我們看一個例子。
計算線性函數:
def linear(x):
return weights @ x
在一批示例 [x₁, x2,..] 中,我們可以天真地(沒有 vmap)實作它,如下所示:
def naively_batched_linear(X_batched):
return jnp.stack([linear(x) for x in X_batched])
相反,通過使用 vmap 對線性進行向量化,我們可以一次性計算整個批次:
def vmap_batched_linear(X_batched):
return vmap(linear)(X_batched)