NDArray可以很友善的求解導數,比如下面的例子:(代碼主要參考自https://zh.gluon.ai/chapter_crashcourse/autograd.html)
用代碼實作如下:
1 import mxnet.ndarray as nd
2 import mxnet.autograd as ag
3 x = nd.array([[1,2],[3,4]])
4 print(x)
5 x.attach_grad() #附加導數存放的空間
6 with ag.record():
7 y = 2*x**2
8 y.backward() #求導
9 z = x.grad #将導數結果(也是一個矩陣)指派給z
10 print(z) #列印結果
[[ 1. 2.]
[ 3. 4.]]
<NDArray 2x2 @cpu(0)>
[[ 4. 8.]
[ 12. 16.]]
<NDArray 2x2 @cpu(0)>
對控制流求導
NDArray還能對諸如if的控制分支進行求導,比如下面這段代碼:
1 def f(a):
2 if nd.sum(a).asscalar()<15: #如果矩陣a的元數和<15
3 b = a*2 #則所有元素*2
4 else:
5 b = a
6 return b
數學公式等價于:
這樣就轉換成本文最開頭示例一樣,變成單一函數求導,顯然導數值就是x前的常數項,驗證一下:
import mxnet.ndarray as nd
import mxnet.autograd as ag
def f(a):
if nd.sum(a).asscalar()<15: #如果矩陣a的元數和<15
b = a*2 #則所有元素平方
else:
b = a
return b
#注:1+2+3+4<15,是以進入b=a*2的分支
x = nd.array([[1,2],[3,4]])
print("x1=")
print(x)
x.attach_grad()
with ag.record():
y = f(x)
print("y1=")
print(y)
y.backward() #dy/dx = y/x 即:2
print("x1.grad=")
print(x.grad)
x = x*2
print("x2=")
print(x)
x.attach_grad()
with ag.record():
y = f(x)
print("y2=")
print(y)
y.backward()
print("x2.grad=")
print(x.grad)
x1=
[[ 1. 2.]
[ 3. 4.]]
<NDArray 2x2 @cpu(0)>
y1=
[[ 2. 4.]
[ 6. 8.]]
<NDArray 2x2 @cpu(0)>
x1.grad=
[[ 2. 2.]
[ 2. 2.]]
<NDArray 2x2 @cpu(0)>
x2=
[[ 2. 4.]
[ 6. 8.]]
<NDArray 2x2 @cpu(0)>
y2=
[[ 2. 4.]
[ 6. 8.]]
<NDArray 2x2 @cpu(0)>
x2.grad=
[[ 1. 1.]
[ 1. 1.]]
<NDArray 2x2 @cpu(0)>
頭梯度
原文上講得很含糊,其實所謂頭梯度,就是一個求導結果前的乘法系數,見下面代碼:
1 import mxnet.ndarray as nd
2 import mxnet.autograd as ag
3
4 x = nd.array([[1,2],[3,4]])
5 print("x=")
6 print(x)
7
8 x.attach_grad()
9 with ag.record():
10 y = 2*x*x
11
12 head = nd.array([[10, 1.], [.1, .01]]) #所謂的"頭梯度"
13 print("head=")
14 print(head)
15 y.backward(head_gradient) #用頭梯度求導
16
17 print("x.grad=")
18 print(x.grad) #列印結果
x=
[[ 1. 2.]
[ 3. 4.]]
<NDArray 2x2 @cpu(0)>
head=
[[ 10. 1. ]
[ 0.1 0.01]]
<NDArray 2x2 @cpu(0)>
x.grad=
[[ 40. 8. ]
[ 1.20000005 0.16 ]]
<NDArray 2x2 @cpu(0)>
對比本文最開頭的求導結果,上面的代碼僅僅多了一個head矩陣,最終的結果,其實就是在正常求導結果的基礎上,再乘上head矩陣(指:數乘而非叉乘)
鍊式法則
先複習下數學
注:最後一行中所有變量x,y,z都是向量(即:矩形),為了不讓公式看上去很淩亂,就統一省掉了變量上的箭頭。NDArray對複合函數求導時,已經自動應用了鍊式法則,見下面的示例代碼:
1 import mxnet.ndarray as nd
2 import mxnet.autograd as ag
3
4 x = nd.array([[1,2],[3,4]])
5 print("x=")
6 print(x)
7
8 x.attach_grad()
9 with ag.record():
10 y = x**2
11 z = y**2 + y
12
13 z.backward()
14
15 print("x.grad=")
16 print(x.grad) #列印結果
17
18 print("w=")
19 w = 4*x**3 + 2*x
20 print(w) # 驗證結果
x=
[[ 1. 2.]
[ 3. 4.]]
<NDArray 2x2 @cpu(0)>
x.grad=
[[ 6. 36.]
[ 114. 264.]]
<NDArray 2x2 @cpu(0)>
w=
[[ 6. 36.]
[ 114. 264.]]
<NDArray 2x2 @cpu(0)>
作者:菩提樹下的楊過
出處:http://yjmyzz.cnblogs.com
本文版權歸作者和部落格園共有,歡迎轉載,但未經作者同意必須保留此段聲明,且在文章頁面明顯位置給出原文連接配接,否則保留追究法律責任的權利。