一. torch.nn.DataParallel ?
pytorch單機多卡最簡單的實作方法就是使用nn.DataParallel類,其幾乎僅使用一行代碼
net = torch.nn.DataParallel(net)
就可讓模型同時在多張GPU上訓練,它大緻的工作過程如下圖所示: ![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsISPrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdsATOfd3bkFGazxCMx8VesATMfhHLlN3XnxCMwEzX0xiRGZkRGZ0Xy9GbvNGLpZTY1EmMZVDUSFTU4VFRR9Fd4VGdsYTMfVmepNHLrJXYtJXZ0F2dvwVZnFWbp1zczV2YvJHctM3cv1Ce-cmbw5CM3IDM1MGMyMmN0EjM1IDOxYzX0QDMxcTMzAzLcFTMxIDMy8CXn9Gbi9CXzV2Zh1WavwVbvNmLvR3YxUjL4M3Lc9CX6MHc0RHaiojIsJye.png)
在每一個Iteration的Forward過程中,nn.DataParallel都自動将輸入按照gpu_batch進行split,然後複制模型參數到各個GPU上,分别進行前傳後将得到網絡輸出,最後将結果concat到一起送往0号卡中。
在Backward過程中,先由0号卡計算loss函數,通過
loss.backward()
得到損失函數相于各個gpu輸出結果的梯度grad_l1 ... gradln,接下來0号卡将所有的grad_l送回對應的GPU中,然後GPU們分别進行backward得到各個GPU上面的模型參數梯度值gradm1 ... gradmn,最後所有參數的梯度彙總到GPU0卡進行update。 - 負載不均衡問題。gpu0所承擔的任務明顯要重于其他gpu
- 速度問題。每個iteration都需要複制模型且均從GPU0卡向其他GPU複制,通訊任務重且效率低;python多線程GIL鎖導緻的線程颠簸(thrashing)問題。
- 隻能單機運作。由于單程序的限制導緻。
- 隻能切分batch到多GPU,而無法讓一個model分布在多個GPU上。當一個模型過大,設定batchsize=1時其顯存占用仍然大于單張顯示卡顯存,此時就無法使用DataParallel類進行訓練。