天天看点

【MXNet】(七):自动求梯度

MXNet中可以使用autograd来自动求梯度。

以对函数 y=2x⊤x求关于列向量 x 的梯度为例。

首先导入模块。

from mxnet import autograd, nd
           

创建变量x,

x = nd.arange(4).reshape((4, 1))
x
           

输出,

【MXNet】(七):自动求梯度

然后调用

attach_grad

函数来申请存储梯度所需要的内存,

x.attach_grad()
           

为了减少计算和内存开销,默认条件下MXNet不会记录用于求梯度的计算。如果要求MXNet记录与求梯度有关的计算,需要调用

record

函数,

with autograd.record():
    y = 2 * nd.dot(x.T, x)
           

接下来调用

backward

函数自动求梯度。

y.backward()
           

如果

y

不是一个标量,MXNet将默认先对

y

中元素求和得到新的变量,再求该变量有关

x

的梯度。

下面来验证一下。函数y关于x的梯度是4x。

assert (x.grad - 4 * x).norm().asscalar() == 0
x.grad
           

输出,

【MXNet】(七):自动求梯度

继续阅读