安裝:pip install apex
參考: https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/100135729
在這篇文章裡,筆者會詳解一下混合精度計算(Mixed Precision),并介紹一款 NVIDIA 開發的基于 PyTorch 的混合精度訓練加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代碼就能實作不同程度的混合精度加速,訓練時間直接縮小一半。
話不多說,直接先教你怎麼用。
PyTorch實作
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 這裡是“歐一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
對,就是這麼簡單,如果你不願意花時間深入了解,讀到這基本就可以直接使用起來了。
但是如果你希望對 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各種不明是以的“Nan”的同學,可以接着讀下去,後面會有一些有趣的理論知識和筆者最近一個月使用 Apex 遇到的各種 bug,不過當你深入了解并解決掉這些 bug 後,你就可以徹底擺脫“慢吞吞”的 FP32 啦。
理論部分
為了充分了解混合精度的原理,以及 API 的