天天看點

PyTorch基于Apex的混合精度加速

安裝:pip install apex

參考: https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/100135729

在這篇文章裡,筆者會詳解一下混合精度計算(Mixed Precision),并介紹一款 NVIDIA 開發的基于 PyTorch 的混合精度訓練加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代碼就能實作不同程度的混合精度加速,訓練時間直接縮小一半。 

話不多說,直接先教你怎麼用。

PyTorch實作

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 這裡是“歐一”,不是“零一”

with amp.scale_loss(loss, optimizer) as scaled_loss:

scaled_loss.backward()

對,就是這麼簡單,如果你不願意花時間深入了解,讀到這基本就可以直接使用起來了。

但是如果你希望對 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各種不明是以的“Nan”的同學,可以接着讀下去,後面會有一些有趣的理論知識和筆者最近一個月使用 Apex 遇到的各種 bug,不過當你深入了解并解決掉這些 bug 後,你就可以徹底擺脫“慢吞吞”的 FP32 啦。

理論部分

為了充分了解混合精度的原理,以及 API 的