天天看點

JAX: 快如 PyTorch,簡單如 NumPy - 深度學習與資料科學

作者:冷凍工廠
JAX: 快如 PyTorch,簡單如 NumPy - 深度學習與資料科學

JAX 是 TensorFlow 和 PyTorch 的新競争對手。 JAX 強調簡單性而不犧牲速度和可擴充性。由于 JAX 需要更少的樣闆代碼,是以程式更短、更接近數學,是以更容易了解。

長話短說:

  • 使用 import jax.numpy 通路 NumPy 函數,使用 import jax.scipy 通路 SciPy 函數。
  • 通過使用 @jax.jit 進行裝飾,可以加快即時編譯速度。
  • 使用 jax.grad 求導。
  • 使用 jax.vmap 進行矢量化,并使用 jax.pmap 進行跨裝置并行化。

函數式程式設計

JAX: 快如 PyTorch,簡單如 NumPy - 深度學習與資料科學

JAX 遵循函數式程式設計哲學。這意味着您的函數必須是獨立的或純粹的:不允許有副作用。本質上,純函數看起來像數學函數(圖 1)。有輸入進來,有東西出來,但與外界沒有溝通。

例子#1

以下代碼片段是一個非功能純的示例。

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total
           

注意 impure_example 之外的偏差。在編譯期間(見下文),偏差可能會被緩存,是以不再反映偏差的變化。

例子#2

這是一個pure的例子。

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation
           

在這裡,pure_example 是獨立的:所有參數都作為參數傳遞。

确定性采樣器

JAX: 快如 PyTorch,簡單如 NumPy - 深度學習與資料科學

在計算機中,不存在真正的随機性。相反,NumPy 和 TensorFlow 等庫會跟蹤僞随機數狀态來生成“随機”樣本。

函數式程式設計的直接後果是随機函數的工作方式不同。由于不再允許全局狀态,是以每次采樣随機數時都需要顯式傳入僞随機數生成器 (PRNG) 密鑰

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
           

此外,您有責任為任何後續調用推進“随機狀态”。

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..
           

jit

您可以通過即時編譯 JAX 指令來加快代碼速度。例如,要編譯縮放指數線性機關 (SELU) 函數,請使用 jax.numpy 中的 NumPy 函數并将 jax.jit 裝飾器添加到該函數,如下所示:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
           

JAX 會跟蹤您的指令并将其轉換為 jaxpr。這使得加速線性代數 (XLA) 編譯器能夠為您的加速器生成非常高效的優化代碼。

gard

JAX 最強大的功能之一是您可以輕松擷取 gard。使用 jax.grad,您可以定義一個新函數,即符号導數。

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))
           

正如您在示例中看到的,您不僅限于一階導數。您可以通過簡單地按順序連結 grad 函數 n 次來擷取 n 階導數。

vmap 和 pmap

矩陣乘法使所有批次尺寸正确需要非常細心。 JAX 的矢量化映射函數 vmap 通過對函數進行矢量化來減輕這種負擔。基本上,每個按元素應用函數 f 的代碼塊都是由 vmap 替換的候選者。讓我們看一個例子。

計算線性函數:

def linear(x):
 return weights @ x
           

在一批示例 [x₁, x2,..] 中,我們可以天真地(沒有 vmap)實作它,如下所示:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])
           

相反,通過使用 vmap 對線性進行向量化,我們可以一次性計算整個批次:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)
           

繼續閱讀