天天看點

【PyTorch】PyTorch中的梯度累加

PyTorch中的梯度累加

使用PyTorch實作梯度累加變相擴大batch

PyTorch中在反向傳播前為什麼要手動将梯度清零? - Pascal的回答 - 知乎

https://www.zhihu.com/question/303070254/answer/573037166

這種模式可以讓梯度玩出更多花樣,比如說梯度累加(gradient accumulation)

傳統的訓練函數,一個batch是這麼訓練的:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)

    # 2. backward
    optimizer.zero_grad()   # reset gradient
    loss.backward()
    optimizer.step()
           
  1. 擷取loss:輸入圖像和标簽,通過infer計算得到預測值,計算損失函數;
  2. optimizer.zero_grad()

    清空過往梯度;
  3. loss.backward()

    反向傳播,計算目前梯度;
  4. optimizer.step()

    根據梯度更新網絡參數

簡單的說就是進來一個batch的資料,計算一次梯度,更新一次網絡

使用梯度累加是這麼寫的:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)

    # 2.1 loss regularization
    loss = loss/accumulation_steps
    # 2.2 back propagation
    loss.backward()

    # 3. update parameters of net
    if((i+1)%accumulation_steps)==0:
        # optimizer the net
        optimizer.step()        # update parameters of net
        optimizer.zero_grad()   # reset gradient
           
  1. loss.backward()

     反向傳播,計算目前梯度;
  2. 多次循環步驟1-2,不清空梯度,使梯度累加在已有梯度上;
  3. 梯度累加了一定次數後,先

    optimizer.step()

     根據累計的梯度更新網絡參數,然後

    optimizer.zero_grad()

     清空過往梯度,為下一波梯度累加做準備;

總結來說:梯度累加就是,每次擷取1個batch的資料,計算1次梯度,梯度不清空,不斷累加,累加一定次數後,根據累加的梯度更新網絡參數,然後清空梯度,進行下一次循環。

一定條件下,batchsize越大訓練效果越好,梯度累加則實作了batchsize的變相擴大,如果

accumulation_steps

 為8,則batchsize '變相' 擴大了8倍,是我們這種乞丐實驗室解決顯存受限的一個不錯的trick,使用時需要注意,學習率也要适當放大。

更新1:關于BN是否有影響,之前有人是這麼說的:

As far as I know, batch norm statistics get updated on each forward pass, so no problem if you don't do

.backward()

 every time.

BN的估算是在forward階段就已經完成的,并不沖突,隻是

accumulation_steps=8

 和真實的batchsize放大八倍相比,效果自然是差一些,畢竟八倍Batchsize的BN估算出來的均值和方差肯定更精準一些。

更新2:根據李韶華的分享,可以适當調低BN自己的momentum參數:

bn自己有個momentum參數:x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum越接近0,老的running stats記得越久,是以可以得到更長序列的統計資訊

我簡單看了下PyTorch 1.0的源碼:https://github.com/pytorch/pytorch/blob/162ad945902e8fc9420cbd0ed432252bd7de673a/torch/nn/modules/batchnorm.py#L24,BN類裡面momentum這個屬性預設為0.1,可以嘗試調節下。

借助梯度累加,避免同時計算多個損失時存儲多個計算圖

PyTorch中在反向傳播前為什麼要手動将梯度清零? - Forever123的回答 - 知乎

https://www.zhihu.com/question/303070254/answer/608153308

原因在于在PyTorch中,計算得到的梯度值會進行累加。

而這樣的好處可以從記憶體消耗的角度來看。

1. Edition1

在PyTorch中,multi-task任務一個标準的train from scratch流程為:

for idx, data in enumerate(train_loader):
    xs, ys = data
    pred1 = model1(xs)
    pred2 = model2(xs)

    loss1 = loss_fn1(pred1, ys)
    loss2 = loss_fn2(pred2, ys)

    ******
    loss = loss1 + loss2
    optmizer.zero_grad()
    loss.backward()
    ++++++
    optmizer.step()
           

從PyTorch的設計原理上來說,在每次進行前向計算得到pred時,會産生一個**用于梯度回傳的計算圖,這張圖儲存了進行back propagation需要的中間結果,當調用了 **

**.backward()**

 後,會從記憶體中将這張圖進行釋放。

  • 上述代碼執行到

    ******

     時,記憶體中是包含了兩張計算圖的,而随着求和得到loss,這兩張圖進行了合并,而且大小的變化可以忽略。
  • 執行到

    ++++++

     時,得到對應的grad值并且釋放記憶體。這樣,訓練時必須存儲兩張計算圖,而如果loss的來源組成更加複雜,記憶體消耗會更大。

2. Edition2

為了減小每次的記憶體消耗,借助梯度累加,又有 ,有如下變種。

for idx, data in enumerate(train_loader):
    xs, ys = data

    optmizer.zero_grad()

    # 計算d(l1)/d(x)
    pred1 = model1(xs) #生成graph1
    loss = loss_fn1(pred1, ys)
    loss.backward()  #釋放graph1

    # 計算d(l2)/d(x)
    pred2 = model2(xs)#生成graph2
    loss2 = loss_fn2(pred2, ys)
    loss.backward()  #釋放graph2

    # 使用d(l1)/d(x)+d(l2)/d(x)進行優化
    optmizer.step()
           

可以從代碼中看出,利用梯度累加,可以在最多儲存一張計算圖的情況下進行multi-task任務的訓練。

3. Other

另外一個理由就是在記憶體大小不夠的情況下疊加多個batch的grad作為一個大batch進行疊代,因為二者得到的梯度是等價的。

綜上可知,這種梯度累加的思路是對記憶體的極大友好,是由FAIR的設計理念出發的。

相關連結

  • PyTorch中在反向傳播前為什麼要手動将梯度清零? - 知乎:https://www.zhihu.com/question/303070254
  • https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/9

本文來自部落格園,作者:lart

創作不易,轉載請注明原文連結:https://www.cnblogs.com/lart/p/11628696.html

歡迎關注我的公衆号,文章更新提醒更及時哦:

【PyTorch】PyTorch中的梯度累加

繼續閱讀