天天看点

PyTorch基于Apex的混合精度加速

安装:pip install apex

参考: https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/100135729

在这篇文章里,笔者会详解一下混合精度计算(Mixed Precision),并介绍一款 NVIDIA 开发的基于 PyTorch 的混合精度训练加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代码就能实现不同程度的混合精度加速,训练时间直接缩小一半。 

话不多说,直接先教你怎么用。

PyTorch实现

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”

with amp.scale_loss(loss, optimizer) as scaled_loss:

scaled_loss.backward()

对,就是这么简单,如果你不愿意花时间深入了解,读到这基本就可以直接使用起来了。

但是如果你希望对 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各种不明所以的“Nan”的同学,可以接着读下去,后面会有一些有趣的理论知识和笔者最近一个月使用 Apex 遇到的各种 bug,不过当你深入理解并解决掉这些 bug 后,你就可以彻底摆脱“慢吞吞”的 FP32 啦。

理论部分

为了充分理解混合精度的原理,以及 API 的