天天看點

pytorch 基于 apex.amp 的混合精度訓練:原理介紹與實作

一、混合精度訓練介紹

所謂天下武功,唯快不破。我們在訓練模型時,往往受制于顯存空間隻能選取較小的 batch size,導緻訓練時間過長,使人逐漸煩躁。那麼有沒有可能在顯存空間不變的情況下提高訓練速度呢?混合精度訓練(Mixed Precision)便油然而生。

1、fp16與fp32

fp16(float16):Half-precision floating-point format 半精度浮點數

fp32(float32):單精度浮點數

fp64(float64):雙精度浮點數

fp16 與 fp32 的存儲方式和精度參考部落格:https://blog.csdn.net/qq_36533552/article/details/105885714

混合精度訓練的精髓在于在記憶體中用 fp16 做儲存和乘法進而加速計算,用 fp32 做累加避免舍入誤差。

2、為什麼要使用混合精度訓練?

神經網絡架構的計算核心是Tensor,pytorch 中定義一個Tensor其預設類型是fp32。目前大多數的深度學習模型使用的是 fp32 進行訓練,而混合精度訓練的方法則通過 fp16 進行深度學習模型訓練,進而減少了訓練深度學習模型所需的記憶體,同時由于 fp16 的運算比 fp32 運算更快,進而也進一步提高了硬體效率。總之,混合16位和32位的計算可以節約GPU顯存和加速神經網絡訓練。

此外,硬體的發展同樣也推動着模型計算的加速,随着Nvidia張量核心(Tensor Core)的普及,16bit計算也一步步走向成熟,低精度計算也是未來深度學習的一個重要趨勢。

總結一下就是:省存儲,省傳播,省計算。

3、使用fp16帶來的問題及解決方法

參考部落格:https://zhuanlan.zhihu.com/p/79887894

參考部落格:https://zhuanlan.zhihu.com/p/165152789

fp16 的優勢是存儲小、計算快、更好的利用CUDA裝置的Tensor Core。是以訓練的時候可以減少顯存的占用(可以增加batchsize了),同時訓練速度更快。

fp16 的劣勢是:數值範圍小(更容易Overflow / Underflow)、舍入誤差(Rounding Error,導緻一些微小的梯度資訊達不到16bit精度的最低分辨率。比如反向求導中很接近0的小數值梯度用fp16表示後變為0,進而導緻梯度消失,訓練停滞。

可見,當 fp16有優勢的時候就用 fp16,而為了消除 fp16 的劣勢,有兩種解決方案:

(1)梯度縮放,通過放大 loss 的值來防止梯度的 underflow(這隻是BP的時候傳遞梯度資訊使用,真正更新權重的時候還是要把放大的梯度再unscale回去)。也就是将loss值放大k倍,根據鍊式法則,反向傳播中的梯度也會放大k倍,原來不能被fp16表示的數就可以被fp16表示。

(2)由 pytorch 自動決定什麼時候用fp16,什麼時候用fp32, 一般用 fp16 做儲存和乘法進而加速計算,用 fp32 做累加避免舍入誤差。如在卷積和全連接配接操作中用fp16,在 Softmax 操作中用fp32 , 這是 amp 自動設定和計算的。

在神經網絡處理器NPU中,在前向計算,反向求導,梯度傳輸時候用fp16,參數更新階段将fp16參數加到參數的fp32副本上。下一輪疊代時,将fp32副本上的參數轉為fp16,用于前向計算。二者之間的轉換為NPU内部自動實作的,操作者不可見也無法幹預。

Loss Scale 分為靜态和動态 Loss Scale,動态 Loss Scale 會自動更改 Loss Scale 的縮放倍數。

二、apex介紹與安裝

apex的全稱是 A PyTorch Extension ,其實就是一種 pytorch 的拓展插件,其本身與混合精度并無關系。apex 是 Nvidia 開發的基于 PyTorch 的混合精度訓練加速神器,是以 Apex 必須在GPU上使用,而不能在CPU中使用。

apex包的nvidia官網介紹:https://developer.nvidia.com/blog/mixed-precision-training-deep-neural-networks/

amp 的全稱是 auto mixed precision,自動混合精度,是一個用來支援模型訓練在pytorch架構下使用混合精度進行加速訓練的拓展插件之類的庫。它最核心的東西在于低精度 fp16 , 它能夠提供一種可靠友好的方式使得模型在 fp16 精度下進行訓練。

從 apex 中引入 amp 的方法是: from apex import amp

pytorch 原生支援的 amp 的使用方法是:from torch.cuda.amp import autocast as autocast, GradScaler

apex安裝過程參考部落格: https://blog.csdn.net/qq_43799400/article/details/118943030

三、apex.amp 的使用

1、三行代碼實作 amp

隻需要在程式中加入這幾行代碼即可(引自apex文檔):

from apex import amp
model, optimizer = amp.initialize(model, optimizer,opt_level="O1",loss_scale=128.0) 
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()    
           

amp 是 pytorch 的自動混合精度,具體介紹可參考:https://zhuanlan.zhihu.com/p/165152789

scale 是縮放的意思,通過放大loss的值來防止梯度下溢,不過這隻是BP的時候傳遞梯度資訊使用,真正更新權重的時候還是要把放大的梯度再unscale回去。

2、參數配置

opt_level 參數:

O0:純FP32訓練,可以作為accuracy的baseline

O1:混合精度訓練,根據黑白名單自動決定使用 FP16 還是 FP32 進行計算。

O2:“幾乎FP16”混合精度訓練,不存在黑白名單,除了Batch norm,幾乎都是用FP16計算。

O3:純FP16訓練,很不穩定,但是可以作為speed的baseline

說明:

推薦優先使用 opt_level=‘O2’, loss_scale=128.0 的配置進行amp.initialize

若無法收斂推薦使用 opt_level=‘O1’, loss_scale=128.0 的配置進行amp.initialize

若依然無法收斂推薦使用 opt_level=‘O1’, loss_scale=None 的配置進行amp.initialize

3、amp測試:mnist手寫數字識别

代碼:

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

############################
# edit this for amp
from apex import amp
############################


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
                    help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                    help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-model', default=True,
                    help='For Saving the current Model')
args = parser.parse_args()

device = torch.device('cuda:0')
torch.manual_seed(args.seed)
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
cuda_kwargs = {'num_workers': 1,'pin_memory': True,'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
dataset1 = datasets.MNIST('./ms', train=True, download=True,transform=transform)
dataset2 = datasets.MNIST('./ms', train=False,transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        #################################################
        # edit this for amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # loss.backward()
        #################################################

        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)


##############################################################################################3
#add this for amp
opt_level = 'O2'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level,loss_scale=128.0)
###############################################################################################


scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()
if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")
    # torch.save(model, "mnist_cnn.pt")  會報錯,隻能儲存模型參數,不能儲存模型
           

【注】經過 Apex 的 model 不能貌似儲存模型,隻能儲存模型參數。是以不能用 torch.save(model, ‘model.pt’) 儲存模型,隻能用 torch.save(model.state_dict(), ‘model.pt’) 儲存模型參數。原因不詳。

四、參考資料推薦

【PyTorch】唯快不破:基于Apex的混合精度加速:https://zhuanlan.zhihu.com/p/79887894

PyTorch的自動混合精度(AMP):https://zhuanlan.zhihu.com/p/165152789

fp16與fp32簡介與試驗:https://blog.csdn.net/qq_36533552/article/details/105885714

pytorch原生支援的apex混合精度和nvidia apex混合精度AMP技術加速模型訓練效果對比:https://blog.csdn.net/HUSTHY/article/details/109485088

繼續閱讀