代碼位址: https://github.com/vic-w/torch-practice/tree/master/multilayer-perceptron
上一次我們使用了輸出節點和輸入節點直接相連的網絡。網絡裡隻有兩個可變參數。這種網絡隻能表示一條直線,不能适應複雜的曲線。我們将把它改造為一個多層網絡。一個輸入節點,然後是兩個隐藏層,每個隐藏層有3個節點,每個隐藏節點後面都跟一個非線性的Sigmoid函數。如圖所示,這次我們使用的網絡是有2個隐藏層,每層有3個節點的多層神經網絡。
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiIyVGduV2QvwVe0lmdhJ3ZvwFM38CXlZHbvN3cpR2Lc1TPB10QGtWUCpEMJ9CXsxWam9CXwADNvwVZ6l2c052bm9CXUJDT1wkNhVzLcRnbvZ2LcZXUYpVd1kmYr50MZV3YyI2cKJDT29GRjBjUIF2LcRHelR3LcJzLctmch1mclRXY39DO5kTMxQjMyEjNwETM1EDMy8CX0Vmbu4GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
那麼這樣的結構用代碼如何表示呢?我們來直接在上一次的代碼上修改。 這個網絡結構是一層一層的疊加起來的,nn庫裡有一個類型叫做Sequential序列,正好适合我們。這個Sequential是一個容器類,我們可以在它裡面添加一些基本的子產品。
model = nn.Sequential()
第一個我們要添加的是輸入節點和第一個隐藏層的連接配接,它是一個Linear線性的類型,它的輸入是1個節點,輸出是3個節點。
model:add(nn.Linear(1,3))
然後我們在他後面添加一個Sigmoid層,它的節點個數會自動和前一層的輸出個數保持一緻。
model:add(nn.Sigmoid())
接下來我們添加第一和第二隐藏層中間的線性連接配接,輸入是3,輸出也是3。
model:add(nn.Linear(3,3))
再添加一個Sigmoid層。
model:add(nn.Sigmoid())
最後是第二隐藏層和輸出節點之間的線性連接配接,輸入是3,輸出是1。
model:add(nn.Linear(3,1))
是以完整的建立模型的代碼看起來是這樣的
model = nn.Sequential()
model:add(nn.Linear(1,3))
model:add(nn.Sigmoid())
model:add(nn.Linear(3,3))
model:add(nn.Sigmoid())
model:add(nn.Linear(3,1))
好,理論上講我們已經改造完了網絡,可以開始訓練了。我們運作一下,看一下結果。我們會很意外的發現這個結果還不如我們上一次的結果。
其實這裡面存在兩個問題: 一個是我們的訓練資料,輸入的月份取值範圍從1到10,輸出的價格取值範圍是幾萬。這樣開始訓練的時候後面幾層的梯度會受到輸出值的影響,變得非常大,迅速的把前面幾層的參數推到一個很大的數值。而Sigmoid函數在遠離零點的位置幾乎梯度為零,是以就一直固定在一個位置不動了。
解決的方法是把輸入和輸出的取值範圍調整到合适的區間,我這裡把輸入除以10,輸出除以50000。預測時再把50000乘回去。在代碼裡面展現,就是在開頭和結尾加兩個輔助層,nn.MulConstant,這種類型的子產品是對網絡中的每個元素乘上一個常數。在輸入進入之前先乘以0.1,在輸入之後乘以50000。 這樣一來,建立模型的代碼就變成了這樣:
model = nn.Sequential()
model:add(nn.MulConstant(0.1)) --在輸入進入之前先乘以0.1
model:add(nn.Linear(1,3))
model:add(nn.Sigmoid())
model:add(nn.Linear(3,3))
model:add(nn.Sigmoid())
model:add(nn.Linear(3,1))
model:add(nn.MulConstant(50000)) --在輸入之後乘以50000
資料預處理問題現在解決了,還有一個問題是訓練的速度很慢。因為我們現在的優化方法用的是最原始梯度下降法。 其實Torch已經給我們提供了各種先進的優化算法,都放在optim這個庫裡。我們在檔案的頭部添加包含optim庫:
require 'optim'
另外,還需要把model裡面的參數找出來友善随時調用。
w, dl_dw = model:getParameters()
w是model裡面所有可調參數的集合,dl_dw是每個參數對loss的偏導數。需要注意的是這裡的w和dl_dw都相當于C++裡面的“引用”,一旦你對他們進行了操作,模型裡的參數也會跟着改變。
優化函數的調用方法有一點特殊,需要你先提供一個目标函數,這個函數相當于C++裡的回調函數,他的輸入是一組網絡權重參數w,輸出有兩個,第一個是網絡使用參數w時,其輸出結果與實際結果之間的差别,也可以叫loss損失,另一個是w中每個參數對于loss的偏導數。
feval = function(w_new)
if w ~= w_new then w:copy(w_new) end
dl_dw:zero()
price_predict = model:forward(month_train)
loss = criterion:forward(price_predict, price_train)
model:backward(month_train, criterion:backward(price_predict, price_train))
return loss, dl_dw
end
這個回調函數可以參照這個例子來寫,同樣是“例行公事”,調用一下反向傳播的算法。
有了這個目标函數,優化疊代的過程就簡單多了。隻需要一句optim.rprop(feval, w, params)。 rprop是一種改進的梯度下降法,它隻看梯度的方向,不管大小,隻要方向不變,它會無限的增大步長,是以他速度非常快。疊代的代碼如下:
params = {
learningRate = 1e-2
}
for i=1,3000 do
optim.rprop(feval, w, params)
if i%10==0 then
gnuplot.plot({month, price}, {month_train:reshape(10), price_predict:reshape(10)})
end
end
其中每10次疊代會把結果用gnuplot畫出來。
我們來運作一下。 在指令行鍵入
th mlp.lua
看一下結果,這次的結果看起來就好多了。綠線(預測值)幾乎和藍線(實際值)重合在一起了。
下一節,我們将介紹如何用卷積神經網絡識别MNIST手寫數字圖像。
本節的完整代碼:
require 'torch'
require 'nn'
require 'optim'
require 'gnuplot'
month = torch.range(1,10)
price = torch.Tensor{28993,29110,29436,30791,33384,36762,39900,39972,40230,40146}
model = nn.Sequential()
model:add(nn.MulConstant(0.1))
model:add(nn.Linear(1,3))
model:add(nn.Sigmoid())
model:add(nn.Linear(3,3))
model:add(nn.Sigmoid())
model:add(nn.Linear(3,1))
model:add(nn.MulConstant(50000))
criterion = nn.MSECriterion()
month_train = month:reshape(10,1)
price_train = price:reshape(10,1)
gnuplot.figure()
w, dl_dw = model:getParameters()
feval = function(w_new)
if w ~= w_new then w:copy(w_new) end
dl_dw:zero()
price_predict = model:forward(month_train)
loss = criterion:forward(price_predict, price_train)
model:backward(month_train, criterion:backward(price_predict, price_train))
return loss, dl_dw
end
params = {
learningRate = 1e-2
}
for i=1,3000 do
optim.rprop(feval, w, params)
if i%10==0 then
gnuplot.plot({month, price}, {month_train:reshape(10), price_predict:reshape(10)})
end
end
month_predict = torch.range(1,12)
local price_predict = model:forward(month_predict:reshape(12,1))
print(price_predict)
gnuplot.pngfigure('plot.png')
gnuplot.plot({month, price}, {month_predict, price_predict})
gnuplot.plotflush()