天天看點

pytorch中的梯度更新

背景

使用pytorch時,有一個yolov3的bug,我認為涉及到學習率的調整。收集到tencent yolov3和mxnet開源的yolov3,兩個優化器中的學習率設定不一樣,而且使用GPU數目和batch的更新也不太一樣。據此,我簡單的了解了下pytorch的權重梯度的更新政策,看看能否一窺究竟。

對代碼說明

共三個實驗,分布寫在代碼中的(一)(二)(三)三個地方。運作實驗時注釋掉其他兩個

實驗及其結果

實驗(三):

不使用zero_grad()時,grad累加在一起,官網是使用accumulate 來表述的,是以不太清楚是取的和還是均值(這兩種最有可能)。

不使用zero_grad()時,是直接疊加add的方式累加的。

tensor([[[ 1.,  1.],……torch.Size([2, 2, 2])
0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 2.,  2.],…… torch.Size([2, 2, 2])
1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 3.,  3.],…… torch.Size([2, 2, 2])
2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
           
實驗(二):

單卡上不同的batchsize對梯度是怎麼作用的。 mini-batch SGD中的batch是加快訓練,同時保持一定的噪聲。但設定不同的batchsize的權重的梯度是怎麼計算的呢。

設定運作實驗(二),可以看到結果如下:是以單卡batchsize計算梯度是取均值的

tensor([[[ 3.,  3.],…… torch.Size([2, 2, 2])
           
實驗(一):

多gpu情況下,梯度怎麼合并在一起的。

在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是當設定g=2,實驗一運作時,結果也是取均值的,類同于實驗(二)

tensor([[[ 3.,  3.],…… torch.Size([2, 2, 2])
           

實驗代碼

import torch
import torch.nn as nn
from torch.autograd import Variable


class model(nn.Module):
    def __init__(self, w):
        super(model, self).__init__()
        self.w = w

    def forward(self, xx):
        b, c, _, _ = xx.shape
        # extra = xx.device.index + 1 ##  實驗(一)
        y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra)
        return y.reshape(len(xx), -1)


g = 1
x = Variable(torch.ones(2, 1, 2, 2))
# x[1] += 1 ## 實驗(二)
w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True)
# optim = torch.optim.SGD({'params': x},
lr = 0.01
momentum = 0.9
M = model(w)

M = torch.nn.DataParallel(M, device_ids=range(g))

for i in range(3):
    b = len(x)
    z = M(x)
    zz = z.sum(1)
    l = (zz - Variable(torch.ones(b).cuda())).mean()
    # zz.backward(Variable(torch.ones(b).cuda()))
    l.backward()
    print(w.grad, w.grad.shape)
    # w.grad.zero_() ## 實驗(三)
    print(i, b, '* * ' * 20)

           

繼續閱讀