天天看點

【MXNet】(七):自動求梯度

MXNet中可以使用autograd來自動求梯度。

以對函數 y=2x⊤x求關于列向量 x 的梯度為例。

首先導入子產品。

from mxnet import autograd, nd
           

建立變量x,

x = nd.arange(4).reshape((4, 1))
x
           

輸出,

【MXNet】(七):自動求梯度

然後調用

attach_grad

函數來申請存儲梯度所需要的記憶體,

x.attach_grad()
           

為了減少計算和記憶體開銷,預設條件下MXNet不會記錄用于求梯度的計算。如果要求MXNet記錄與求梯度有關的計算,需要調用

record

函數,

with autograd.record():
    y = 2 * nd.dot(x.T, x)
           

接下來調用

backward

函數自動求梯度。

y.backward()
           

如果

y

不是一個标量,MXNet将預設先對

y

中元素求和得到新的變量,再求該變量有關

x

的梯度。

下面來驗證一下。函數y關于x的梯度是4x。

assert (x.grad - 4 * x).norm().asscalar() == 0
x.grad
           

輸出,

【MXNet】(七):自動求梯度

繼續閱讀