
prune resnet18

I've recently been studying pruning for model compression,

but I wasn't quite sure how pruning is actually implemented,

so I looked up someone else's code and added my own comments along the way to help me understand it.

This walkthrough prunes a ResNet18 that has already been trained on CIFAR-10.

The code comes from:

https://github.com/kentaroy47/Deep-Compression.Pytorch

Below is the prune module.

# -*- coding: utf-8 -*-

'''Deep Compression with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar

import numpy as np

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Pruning')
parser.add_argument('--loadfile', '-l', default="checkpoint/res18.t7", dest='loadfile')
parser.add_argument('--prune', '-p', default=0.5, type=float, dest='prune', help='fraction of weights to prune')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--net', default='res18')
args = parser.parse_args()

prune = args.prune  # pruning ratio, e.g. 0.5 prunes 50% of the weights

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)

# Model
print('==> Building model..')
if args.net=='res18':
    net = ResNet18()
elif args.net=='vgg':
    net = VGG('VGG19')
else:
    raise ValueError('unsupported --net: ' + args.net)
    
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True


# Load weights from checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isfile(args.loadfile), 'Error: checkpoint file not found!'
checkpoint = torch.load(args.loadfile)  # the checkpoint is a dict
net.load_state_dict(checkpoint['net'])
# dict_keys(['acc', 'epoch', 'net', 'address', 'mask']), len(checkpoint) = 5
print(checkpoint.keys())
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def prune_weights(torchweights):
    weights = torchweights.cpu().numpy()
    weightshape = weights.shape  # a tuple with the length of each dimension
    # Rank the weights by absolute value. A single argsort returns the
    # indices that would sort the array; applying argsort twice yields the
    # rank of each element (rank 0 = smallest magnitude).
    rankedweights = np.abs(weights).reshape(weights.size).argsort().argsort()

    num = weights.size
    prune_num = int(np.round(num * prune))
    print('prune_num:', prune_num)

    # Keep a weight (mask = 1) only if its magnitude rank is at or above the
    # cutoff; the prune_num smallest-magnitude weights get mask = 0.
    masks = (rankedweights >= prune_num).astype(np.float32)
    count = num - int(masks.sum())

    print("total weights:", num)
    print("weights pruned:", count)

    # Reshape the 0/1 mask back to the weight tensor's shape; multiplying
    # elementwise zeroes out the pruned weights while keeping their signs.
    masks = masks.reshape(weightshape)
    weights = masks * weights

    return torch.from_numpy(weights).to(device), masks
'''Example output for one layer:
pruning layer: module.layer1.0.conv2.weight
prune_num: 18432
total weights: 36864
weights pruned: 18432
'''
    
# Prune the weights.
# The names of the pruned tensors are saved in the addressbook and their
# 0/1 masks in the maskbook; both are used during training to keep the
# pruned weights at zero.
addressbook = []
maskbook = []
# state_dict().items() yields (parameter name, tensor) pairs
for k, v in net.state_dict().items():
    if "conv2" in k:
        addressbook.append(k)
        # k looks like module.layer*.*.conv2.weight
        print("pruning layer:", k)
        weights, masks = prune_weights(v)
        maskbook.append(masks)
        checkpoint['net'][k] = weights
        
checkpoint['address'] = addressbook
checkpoint['mask'] = maskbook
net.load_state_dict(checkpoint['net'])
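To verify that the pruning actually took effect, here is a quick sanity check of my own (not in the original repo); each pruned layer should now be roughly prune*100% zeros:

# Sanity check: report the fraction of zeroed weights per pruned layer.
state = net.state_dict()
for address in addressbook:
    w = state[address].cpu().numpy()
    print(address, 'sparsity: %.2f' % (1.0 - float(np.count_nonzero(w)) / w.size))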

# Training

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)

def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        
        optimizer.step()

        # Re-apply the masks after the optimizer step; otherwise Adam's
        # update would push the pruned weights away from zero again.
        # state_dict() returns tensors that share storage with the model
        # parameters, so the in-place mul_ really zeroes the weights inside
        # the network; re-assigning a brand-new tensor to the dict key would
        # only change the dict, not the model.
        state = net.state_dict()
        for address, mask in zip(addressbook, maskbook):
            m = torch.from_numpy(mask).to(device=device, dtype=state[address].dtype)
            state[address].mul_(m)

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
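As an alternative to re-masking after every step, one could zero the gradients of the pruned positions instead: since the pruned weights start at exactly zero and their gradients stay masked, the optimizer leaves them at zero. A minimal sketch of my own (not part of the original code), to be run once after the maskbook is built and before training starts:

# Sketch: register gradient hooks so pruned positions never receive updates.
mask_tensors = {addr: torch.from_numpy(m).float().to(device)
                for addr, m in zip(addressbook, maskbook)}
for name, param in net.named_parameters():
    if name in mask_tensors:
        # multiply each incoming gradient by this layer's 0/1 mask
        param.register_hook(lambda grad, m=mask_tensors[name]: grad * m)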

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/pruned-'+args.net+'-ckpt.t7')
        best_acc = acc


if __name__ == '__main__':
    for epoch in range(start_epoch, start_epoch+20):
        train(epoch)
        test(epoch)
        with open("prune-results-"+str(prune)+'-'+str(args.net)+".txt", "a") as f: 
            f.write(str(epoch)+"\n")
            f.write(str(best_acc)+"\n")
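Assuming the script is saved as prune.py (the file name is my own choice) and a trained ResNet18 checkpoint exists at checkpoint/res18.t7, it can be run as:

python prune.py --loadfile checkpoint/res18.t7 --prune 0.5 --net res18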
           

There is still a lot here I don't fully understand; I'm writing down my learning process. Day day up!
