MXNet中可以使用autograd來自動求梯度。
以對函數 y=2x⊤x求關于列向量 x 的梯度為例。
首先導入子產品。
from mxnet import autograd, nd
建立變量x,
x = nd.arange(4).reshape((4, 1))
x
輸出,
然後調用
attach_grad
函數來申請存儲梯度所需要的記憶體,
x.attach_grad()
為了減少計算和記憶體開銷,預設條件下MXNet不會記錄用于求梯度的計算。如果要求MXNet記錄與求梯度有關的計算,需要調用
record
函數,
with autograd.record():
y = 2 * nd.dot(x.T, x)
接下來調用
backward
函數自動求梯度。
y.backward()
如果
y
不是一個标量,MXNet将預設先對
y
中元素求和得到新的變量,再求該變量有關
x
的梯度。
下面來驗證一下。函數y關于x的梯度是4x。
assert (x.grad - 4 * x).norm().asscalar() == 0
x.grad
輸出,