天天看點

用SymPy簡化神經網絡的求導

神經網絡模型

這裡不重點介紹神經網絡模型,這裡有神經網絡比較簡潔的介紹和推導。[機器學習] Coursera ML筆記

SymPy(符号計算架構)的安裝

我的系統為Ubuntu 14

安裝比較簡單:sudo apt-get install python-sympy【全部小寫,csdn自動變成大寫了◔ ‸◔?】

求導

為了簡化叙述這裡不用求和符号,w,b,x均為矩陣形式。

在python終端輸入:

from sympy import *
 w=Symbol('w')
 b=Symbol('b')
 x=Symbol('x')
 y=Symbol('y')
 
 #或者用下面的方式生成符号
 #w,b,x,y=symbols('w b x y')
 
 a=w*x+b
 print a
           

這裡可以看到a的輸出 w ⋅ x + b w⋅x+b w⋅x+b

這裡假設激活函數是tanh

out=tanh(a)
 e=(y-out)**2
 print e
           

輸出: ( y − t a n h ( b + w ⋅ x ) ) 2 (y - tanh(b + w⋅x))^2 (y−tanh(b+w⋅x))2

這裡對e求w的微分

print diff(e,w)
           

輸出: − 2 ⋅ x ⋅ ( y − t a n h ( b + w ⋅ x ) ) ⋅ ( − t a n h 2 ( b + w ⋅ x ) + 1 ) -2⋅x⋅(y - tanh(b + w⋅x))⋅(- tanh ^2(b + w⋅x) + 1) −2⋅x⋅(y−tanh(b+w⋅x))⋅(−tanh2(b+w⋅x)+1)

對e求b的微分

print diff(e,b)
           

輸出: ( y − t a n h ( b + w ⋅ x ) ) ⋅ ( 2 ⋅ t a n h 2 ( b + w ⋅ x ) − 2 ) (y - tanh(b + w⋅x))⋅(2⋅tanh^2 (b + w⋅x) - 2) (y−tanh(b+w⋅x))⋅(2⋅tanh2(b+w⋅x)−2)

求導數是不是變得很簡單?

總結

當然還有很多其他機器學習算法,如線性回歸,邏輯回歸等等,可以通過sympy的形式求導和積分,或者去驗證自己的導數是否是正确的。這裡隻是一個入門級别的sympy使用介紹。詳細可以點選這裡有比較全面的tutorial,還有pdf文檔