天天看點

使用Pytorch進行單機多卡分布式訓練

一. torch.nn.DataParallel ?

pytorch單機多卡最簡單的實作方法就是使用nn.DataParallel類,其幾乎僅使用一行代碼

net = torch.nn.DataParallel(net)

就可讓模型同時在多張GPU上訓練,它大緻的工作過程如下圖所示:
使用Pytorch進行單機多卡分布式訓練

在每一個Iteration的Forward過程中,nn.DataParallel都自動将輸入按照gpu_batch進行split,然後複制模型參數到各個GPU上,分别進行前傳後将得到網絡輸出,最後将結果concat到一起送往0号卡中。

在Backward過程中,先由0号卡計算loss函數,通過

loss.backward()

得到損失函數相于各個gpu輸出結果的梯度grad_l1 ... gradln,接下來0号卡将所有的grad_l送回對應的GPU中,然後GPU們分别進行backward得到各個GPU上面的模型參數梯度值gradm1 ... gradmn,最後所有參數的梯度彙總到GPU0卡進行update。

  1. 負載不均衡問題。gpu0所承擔的任務明顯要重于其他gpu
  2. 速度問題。每個iteration都需要複制模型且均從GPU0卡向其他GPU複制,通訊任務重且效率低;python多線程GIL鎖導緻的線程颠簸(thrashing)問題。
  3. 隻能單機運作。由于單程序的限制導緻。
  4. 隻能切分batch到多GPU,而無法讓一個model分布在多個GPU上。當一個模型過大,設定batchsize=1時其顯存占用仍然大于單張顯示卡顯存,此時就無法使用DataParallel類進行訓練。

繼續閱讀