天天看点

pytorch 基于 apex.amp 的混合精度训练:原理介绍与实现

一、混合精度训练介绍

所谓天下武功,唯快不破。我们在训练模型时,往往受制于显存空间只能选取较小的 batch size,导致训练时间过长,使人逐渐烦躁。那么有没有可能在显存空间不变的情况下提高训练速度呢?混合精度训练(Mixed Precision)便油然而生。

1、fp16与fp32

fp16(float16):Half-precision floating-point format 半精度浮点数

fp32(float32):单精度浮点数

fp64(float64):双精度浮点数

fp16 与 fp32 的存储方式和精度参考博客:https://blog.csdn.net/qq_36533552/article/details/105885714

混合精度训练的精髓在于在内存中用 fp16 做储存和乘法从而加速计算,用 fp32 做累加避免舍入误差。

2、为什么要使用混合精度训练?

神经网络框架的计算核心是Tensor,pytorch 中定义一个Tensor其默认类型是fp32。目前大多数的深度学习模型使用的是 fp32 进行训练,而混合精度训练的方法则通过 fp16 进行深度学习模型训练,从而减少了训练深度学习模型所需的内存,同时由于 fp16 的运算比 fp32 运算更快,从而也进一步提高了硬件效率。总之,混合16位和32位的计算可以节约GPU显存和加速神经网络训练。

此外,硬件的发展同样也推动着模型计算的加速,随着Nvidia张量核心(Tensor Core)的普及,16bit计算也一步步走向成熟,低精度计算也是未来深度学习的一个重要趋势。

总结一下就是:省存储,省传播,省计算。

3、使用fp16带来的问题及解决方法

参考博客:https://zhuanlan.zhihu.com/p/79887894

参考博客:https://zhuanlan.zhihu.com/p/165152789

fp16 的优势是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快。

fp16 的劣势是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率。比如反向求导中很接近0的小数值梯度用fp16表示后变为0,从而导致梯度消失,训练停滞。

可见,当 fp16有优势的时候就用 fp16,而为了消除 fp16 的劣势,有两种解决方案:

(1)梯度缩放,通过放大 loss 的值来防止梯度的 underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去)。也就是将loss值放大k倍,根据链式法则,反向传播中的梯度也会放大k倍,原来不能被fp16表示的数就可以被fp16表示。

(2)由 pytorch 自动决定什么时候用fp16,什么时候用fp32, 一般用 fp16 做储存和乘法从而加速计算,用 fp32 做累加避免舍入误差。如在卷积和全连接操作中用fp16,在 Softmax 操作中用fp32 , 这是 amp 自动设定和计算的。

在神经网络处理器NPU中,在前向计算,反向求导,梯度传输时候用fp16,参数更新阶段将fp16参数加到参数的fp32副本上。下一轮迭代时,将fp32副本上的参数转为fp16,用于前向计算。二者之间的转换为NPU内部自动实现的,操作者不可见也无法干预。

Loss Scale 分为静态和动态 Loss Scale,动态 Loss Scale 会自动更改 Loss Scale 的缩放倍数。

二、apex介绍与安装

apex的全称是 A PyTorch Extension ,其实就是一种 pytorch 的拓展插件,其本身与混合精度并无关系。apex 是 Nvidia 开发的基于 PyTorch 的混合精度训练加速神器,因此 Apex 必须在GPU上使用,而不能在CPU中使用。

apex包的nvidia官网介绍:https://developer.nvidia.com/blog/mixed-precision-training-deep-neural-networks/

amp 的全称是 auto mixed precision,自动混合精度,是一个用来支持模型训练在pytorch框架下使用混合精度进行加速训练的拓展插件之类的库。它最核心的东西在于低精度 fp16 , 它能够提供一种可靠友好的方式使得模型在 fp16 精度下进行训练。

从 apex 中引入 amp 的方法是: from apex import amp

pytorch 原生支持的 amp 的使用方法是:from torch.cuda.amp import autocast as autocast, GradScaler

apex安装过程参考博客: https://blog.csdn.net/qq_43799400/article/details/118943030

三、apex.amp 的使用

1、三行代码实现 amp

只需要在程序中加入这几行代码即可(引自apex文档):

from apex import amp
model, optimizer = amp.initialize(model, optimizer,opt_level="O1",loss_scale=128.0) 
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()    
           

amp 是 pytorch 的自动混合精度,具体介绍可参考:https://zhuanlan.zhihu.com/p/165152789

scale 是缩放的意思,通过放大loss的值来防止梯度下溢,不过这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去。

2、参数配置

opt_level 参数:

O0:纯FP32训练,可以作为accuracy的baseline

O1:混合精度训练,根据黑白名单自动决定使用 FP16 还是 FP32 进行计算。

O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。

O3:纯FP16训练,很不稳定,但是可以作为speed的baseline

说明:

推荐优先使用 opt_level=‘O2’, loss_scale=128.0 的配置进行amp.initialize

若无法收敛推荐使用 opt_level=‘O1’, loss_scale=128.0 的配置进行amp.initialize

若依然无法收敛推荐使用 opt_level=‘O1’, loss_scale=None 的配置进行amp.initialize

3、amp测试:mnist手写数字识别

代码:

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

############################
# edit this for amp
from apex import amp
############################


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
                    help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                    help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-model', default=True,
                    help='For Saving the current Model')
args = parser.parse_args()

device = torch.device('cuda:0')
torch.manual_seed(args.seed)
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
cuda_kwargs = {'num_workers': 1,'pin_memory': True,'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
dataset1 = datasets.MNIST('./ms', train=True, download=True,transform=transform)
dataset2 = datasets.MNIST('./ms', train=False,transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        #################################################
        # edit this for amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # loss.backward()
        #################################################

        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)


##############################################################################################3
#add this for amp
opt_level = 'O2'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level,loss_scale=128.0)
###############################################################################################


scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()
if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")
    # torch.save(model, "mnist_cnn.pt")  会报错,只能保存模型参数,不能保存模型
           

【注】经过 Apex 的 model 不能貌似保存模型,只能保存模型参数。因此不能用 torch.save(model, ‘model.pt’) 保存模型,只能用 torch.save(model.state_dict(), ‘model.pt’) 保存模型参数。原因不详。

四、参考资料推荐

【PyTorch】唯快不破:基于Apex的混合精度加速:https://zhuanlan.zhihu.com/p/79887894

PyTorch的自动混合精度(AMP):https://zhuanlan.zhihu.com/p/165152789

fp16与fp32简介与试验:https://blog.csdn.net/qq_36533552/article/details/105885714

pytorch原生支持的apex混合精度和nvidia apex混合精度AMP技术加速模型训练效果对比:https://blog.csdn.net/HUSTHY/article/details/109485088

继续阅读