天天看點

【增強學習】Torch中的增強學習層理念監督學習實驗

要想在Torch架構下解決計算機視覺中的增強學習問題(例如Visual Attention),可以使用Nicholas Leonard提供的dpnn包。這個包對Torch中原有nn包進行了強大的擴充,在Torch官方的教程中予以推薦。

增強學習是一個非常廣闊的領域,本文隻從了解代碼的角度進行最簡單說明。

理念

機器學習可以分為三支:

  • 無監督學習(unsupervised):沒有教師信号。例:給定一系列圖檔,學習網絡參數,能夠按照訓練集的機率,生成新的圖檔。聚類也是一種無監督學習。
  • 監督學習(supervised):有教師信号,教導正确的動作。例:手把手教你下棋:這種局面下這兒,那種局面下那兒。常見的圖像識别問題也是監督學習。
  • 增強學習(reinforce):有教師信号,但不知道正确的動作是什麼,隻給出動作的回報。例:其實我也太不會下棋,但是你下完我知道數子兒判斷輸赢。

對于同一個問題,可以用不同的觀點來看待。以手寫數字識别任務為例:

監督學習觀點下,把分類結果看成動作,教師給出的是标定的分類結果。

增強學習觀點下,把每次觀察位置看成動作1,教師給出的回報是分類正确與否。

監督學習

作為複習,先來看看監督學習網絡的結構:

【增強學習】Torch中的增強學習層理念監督學習實驗

網絡輸入為 x x x,輸出為 y y y。輸出由網絡參數 θ \theta θ控制:

y = f ( x ; θ ) y=f(x;\theta) y=f(x;θ)

訓練時,需要在後面接一個**準則(Criterion)**子產品充當教師,給出一個代價函數 E E E。

準則子產品是解析的,可以求出代價對輸出的導數:

g r a d O u t = ∂ E ∂ y gradOut=\frac{\partial E}{\partial y} gradOut=∂y∂E​

根據鍊式法則,進一步可以求出代價對輸入的導數:

g r a d I n = g r a d O u t ⋅ ∂ y ∂ x = ∂ E ∂ x gradIn = gradOut \cdot \frac{\partial y}{\partial x} = \frac{\partial E}{\partial x} gradIn=gradOut⋅∂x∂y​=∂x∂E​

網絡子產品可以串接:

g r a d I n n + 1 = g r a d I n n ⋅ ∂ y n ∂ x n {gradIn}_{n+1}={gradIn}_n \cdot \frac{\partial y_n}{\partial x_n} gradInn+1​=gradInn​⋅∂xn​∂yn​​

訓練時,網絡參數更新:

− Δ θ n = ∂ E ∂ θ n = g r a d I n n + 1 ⋅ ∂ y n ∂ θ n -\Delta \theta_n = \frac{\partial E}{\partial \theta_n}={gradIn}_{n+1} \cdot \frac{\partial y_n}{\partial \theta_n} −Δθn​=∂θn​∂E​=gradInn+1​⋅∂θn​∂yn​​

最關鍵的部分是:求出每一層的gradIn,之後傳遞給前一層。

增強學習

增強學習網絡具有如下結構:

【增強學習】Torch中的增強學習層理念監督學習實驗

乍一看差不多,但有兩處不同。

第一,增強學習的層是統計的(stochastic),而不是确定的。也就是說, y y y是個随機變量,服從圍繞 x x x的某種分布。

y ∼ f ( x , y ; θ ) y\sim f(x,y;\theta) y∼f(x,y;θ)

舉例: f ( x , y ; θ ) = N ( y ; θ T x , σ ) f(x,y;\theta)=N(y;\theta^Tx,\sigma) f(x,y;θ)=N(y;θTx,σ), σ \sigma σ為預設參數。

實體意義

在訓練時,随機采樣層圍繞着現有政策,生成一些探索政策。在測試時,可以把這一層變成确定的。

訓練時,需要在後面接一個回報(Reward)子產品充當教師,回報是規則式的,對于輸入往往不可導,這是第二點不同。是以增強學習中沒有gradOutput這項。

回報對輸入的導數如下求解:

g r a d I n = ∂ R ∂ x = R ⋅ ∂ ln ⁡ f ( x , y ; θ ) ∂ x gradIn = \frac{\partial R}{\partial x}=R \cdot \frac{\partial \ln f(x,y;\theta)}{\partial x} gradIn=∂x∂R​=R⋅∂x∂lnf(x,y;θ)​

實體意義

∂ ln ⁡ f / ∂ x \partial \ln f / \partial x ∂lnf/∂x - 找到機率相對于輸入的變化方向。 x x x沿着這個方向變化,則生成目前動作 y y y的機率增大得最快。

R R R - 用回報來給上述變化權重。回報為正,說明動作 y y y選擇得當,則 x x x正向變化,提升輸出 y y y的機率;回報為負, x x x向反方向變化,修改目前政策。

在實作時,為了更清晰簡潔,把增強學習層中的參數部分抽出來,變成一個監督學習層:

x ˉ = θ T x \bar x = \theta^Tx xˉ=θTx

y ∼ N ( y ; x ˉ , σ ) y \sim N(y;\bar x, \sigma) y∼N(y;xˉ,σ)

即:增強學習層本身不包含參數,在測試時正向生成采樣,在訓練時反向生成gradInput。

Variance Reduction

對上式中的權重做修改:

g r a d I n = ∂ R ∂ x = ( R − b ) ⋅ ∂ ln ⁡ f ( x , y ) ∂ x gradIn = \frac{\partial R}{\partial x}=(R-b) \cdot \frac{\partial \ln f(x,y)}{\partial x} gradIn=∂x∂R​=(R−b)⋅∂x∂lnf(x,y)​

其中 b b b稱為baseline,用來降低上式的絕對值,稱為variance reduction,簡稱VR。

b b b是一個随着訓練而變化的參數,設定為回報的期望: b = E [ R ] b=E\left [R\right] b=E[R]。

舉例:分類正确 R = 1 R=1 R=1,分類錯誤 R = 0 R=0 R=0。目前網絡能正确分類20%樣本,則 b = 0.2 b=0.2 b=0.2。權重範圍從 [ 0 , 1 ] [0,1] [0,1]變成 [ − 0.2 , 0.8 ] [-0.2,0.8] [−0.2,0.8]。

實體意義

VR實際上對回報進行了歸一化:比“目前政策”好的政策,才值得表彰,給予正回報。

混合網絡

實際應用中的網絡,往往混合了監督學習層(Sup)和增強學習層(Rei)。

【增強學習】Torch中的增強學習層理念監督學習實驗

訓練時,回報直接傳播給增強學習層(紅線)。增強學習層之後的梯度(gradInput3, gradOutput)都是0。

【增強學習】Torch中的增強學習層理念監督學習實驗

為了避免後續子產品沒法優化,實際中往往還需要一個監督學習中可導的Criterion子產品(綠線)。

【增強學習】Torch中的增強學習層理念監督學習實驗

Criterion子產品帶來的梯度傳播到增強學習子產品就停止了。換句話說:增強學習層出口處的gradOut是被忽略的。

實驗

構造網絡

我們在Torch中構造一個簡單的網絡來研究增強學習層。這個網絡參數已經設定好了,沒有任何實際意義。

require 'dpnn'

-- ------ CREATE NET --------
net = nn.Sequential()
-- [1]: fully connect
layer_fc1 = nn.Linear(3,2)
weight = torch.DoubleTensor(2,3)
for h=1,2 do
    for w=1,3 do
        weight[h][w] = (h-1)*3+w
    end
end
layer_fc1.weight = weight
layer_fc1.bias:fill(0)
net:add(layer_fc1)

-- [2]: reinforce
net:add(nn.ReinforceNormal(0.2,false))    -- 0.2: gaussian sampling variance; false: do not sample when testing

-- [3]: fully connect
layer_fc3 = nn.Linear(2,2)
layer_fc3.weight[1][1] = 1
layer_fc3.weight[1][2] = 0
layer_fc3.weight[2][1] = 0
layer_fc3.weight[2][2] = 1
layer_fc3.bias:fill(0)
net:add(layer_fc3)
           

網絡結構是這樣:

【增強學習】Torch中的增強學習層理念監督學習實驗

可以用如下語句檢視網絡結構:

for i=1,#net do
    print(net:get(i))
end
           

前向傳播

構造一個輸入: [ 1 , 1 , 1 ] [1,1,1] [1,1,1],送入網絡:

x = torch.Tensor(3):fill(1)       -- don't make Tensor(1,3): ReinforceNormal will give trouble
y = net:forward(x)
print(x)
for i=1,3 do
    print(net:get(i).output)
end
           

如果不考慮增強學習層的高斯采樣,網絡應該輸出 [ 6 , 15 ] [6,15] [6,15],加上随機采樣之後,輸出圍繞這個值波動。

接上增強學習子產品

如果網絡中包含ReinforceXXX層,必須使用帶有回報的教師子產品。我們把網絡輸出y看做針對兩類的打分,使用VRClassReward子產品。

注意:建立增強學習評估時,必須指明連接配接的網絡,因為要将回報廣播給網絡中的所有增強學習層。

VRClassReward的輸入為各類得分,如果最高得分類和标定相同,則回報+1,否則回報-1。除了y,VRClassReward還需要兩個資料:一個标量baseline,一個标量ground_truth指出正确的類标。

baseline = torch.Tensor(1):fill(0.5)
ground_truth = torch.Tensor(1):fill(2)
           

評估層輸入包含分類得分y和baseline,用前向傳播獲得損失。由于y[2]>y[1],分類為2結果正确,獲得的reward是1,loss為-1:

eval子產品後向傳播獲得gradOut

gradient_eval是個長度為2的sequence。gradient_eval[1]為針對輸出y的梯度,gradient_eval[2]為baseline的期望。

我們暫時不更新baseline,隻更新網絡參數,可以得到出口處的梯度:

經過反向傳播之後,可以看到回報 R − b R-b R−b被存儲在增強學習層中reward變量中。

更多資訊,可以參看ReinforceNormal.lua中的updateOutput和updateGradInput函數。

接上監督學習子產品

如果隻使用一個監督學習層作為教師子產品,ReinforceNormal層的reward沒有指派,無法執行反向傳播。

使用ParallelCriterion層,把增強學習和監督學習的評估函數聯合起來。

eval_reward = nn.VRClassReward(net)
eval_criterion = nn.MSECriterion()   --mean square error
eval = nn.ParallelCriterion():add(eval_reward):add(eval_criterion)
           

監督學習的真值和輸出y尺寸相同。兩個真值組合成gt:

gt_reward = torch.Tensor(1):fill(2)
gt_criterion = torch.Tensor(2)
gt_criterion[1] = 9
gt_criterion[2] = 10
gt = {gt_reward, gt_criterion}
           

執行前向傳播時,分别設定給增強學習和監督學習的輸入:

baseline = torch.Tensor(1):fill(0.5) 
input = {{y,baseline},y}    -- two input
loss = eval:forward(input,gt)
           

反向傳播的結果和輸入結構類似

gradient_eval結構

th> {
..>   1 :
..>     {
..>       1 : DoubleTensor - size: 2
..>       2 : DoubleTensor - size: 1
..>     }
..>   2 : DoubleTensor - size: 2
..> }
..>
           

和網絡輸出y相關的是gradient_eval[1][1]和gradient_eval[2],兩者相加,作為gradOut:

gradient_eval_hybrid =gradient_eval[1][1] + gradient_eval[2] 
gradient_net = net:backward(x, gradient_eval_hybrid)
           
  1. 這裡使用了visual attention的概念。不是一次看整張圖進行判斷。而是每次看一小部分,邊看邊移動,多次觀察進行判斷。 ↩︎

繼續閱讀