天天看點

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

論文題目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

論文連結:https://arxiv.org/pdf/1903.12136.pdf

摘要

在自然語言處理文獻中,神經網絡變得越來越深入和複雜。這一趨勢的苗頭就是深度語言表示模型,其中包括BERT、ELMo和GPT。這些模型的出現和演進甚至導緻人們相信上一代、較淺的語言了解神經網絡(例如LSTM)已經過時了。然而這篇論文證明了如果沒有網絡架構的改變、不加入外部訓練資料或其他的輸入特征,基本的“輕量級”神經網絡仍然可以具有競争力。文本将最先進的語言表示模型BERT中的知識提煉為單層BiLSTM,以及用于句子對任務的暹羅對應模型。在語義了解、自然語言推理和情緒分類的多個資料集中,

知識蒸餾

模型獲得了與ELMo的相當結果,參數量隻有ELMo的大約1/100倍,而推理時間快了15倍。

1 簡介

關于自然語言處理研究中,神經網絡模型已經成了主力軍,并且模型結構層出不窮,好像永無止境一樣,這些過程中最開始的神經網絡例如LSTM變得容易被忽視。例如ELMo模型在2018年一些列任務上取得了sota效果,再到雙向編碼表示模型Bert、GPT-2在更多任務上取得了很大提升。

但是如此之大的模型在實踐落地的過程中是存在問題的:

  • 由于參數量特别大,例如 BERT 和 GPT-2,在移動裝置等資源受限的系統中是不可部署的。
  • 由于推理時間效率低,它們也可能不适用于實時系統,對于QPS壓測很多場景基本是不過關的。
  • 根據摩爾定律可知,我們需要在一定時間過後重新壓縮模型以及重新評估模型性能。

針對上述問題,本文提出了一種基于領域知識的高效遷移學習方法:

  • 作者将BERT-large蒸餾到了單層的BiLSTM中,參數量減少了100倍,速度提升了15倍,效果雖然比BERT差不少,但可以和ELMo打成平手。
  • 同時因為任務資料有限,作者基于以下規則進行了10+倍的資料擴充:用[MASK]随機替換單詞;基于POS标簽替換單詞;從樣本中随機取出n-gram作為新的樣本

2 相關工作

關于模型壓縮的背景介紹,大家可以看下 李rumor的文章https://zhuanlan.zhihu.com/p/273378905,總結比較精煉和到位,這裡不再重複贅述:

Hinton在NIPS2014[1]提出了知識蒸餾(Knowledge Distillation)的概念,旨在把一個大模型或者多個模型ensemble學到的知識遷移到另一個輕量級單模型上,友善部署。簡單的說就是用小模型去學習大模型的預測結果,而不是直接學習訓練集中的label。
在蒸餾的過程中,我們将原始大模型稱為教師模型(teacher),新的小模型稱為學生模型(student),訓練集中的标簽稱為hard label,教師模型預測的機率輸出為soft label,temperature(T)是用來調整soft label的超參數。
蒸餾這個概念之是以work,核心思想是因為好模型的目标不是拟合訓練資料,而是學習如何泛化到新的資料。是以蒸餾的目标是讓學生模型學習到教師模型的泛化能力,理論上得到的結果會比單純拟合訓練資料的學生模型要好。

在BERT提出後,如何瘦身就成了一個重要分支。主流的方法主要有剪枝、蒸餾和量化。量化的提升有限,是以免不了采用剪枝+蒸餾的融合方法來擷取更好的效果。接下來将介紹BERT蒸餾的主要發展脈絡,從各個研究看來,蒸餾的提升一方面來源于從精調階段蒸餾->預訓練階段蒸餾,另一方面則來源于蒸餾最後一層知識->蒸餾隐層知識->蒸餾注意力矩陣。

3 模型方法

本篇論文第一步選擇teacher 模型和student模型,第二步确立蒸餾程式:确立logit-regression目标函數和遷移資料集建構。

3.1 模型選擇

對于“teacher”模型,本文選擇Bert去做微調任務,比如文本分類,文本對分類等。對文本分類,可以直接将文本輸入到bert,拿到cls輸出直接softmax,可以得到每個标簽機率:

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

,其中W\in R^{k *d}

是softmax權重矩陣,k是類别個數。對于文本對任務,我們可以直接兩個文本輸入到Bert提取特征,然後收入到softmax進行分類。

對于“student”模型,本文選擇的是BiLSTM和一個非線性分類器。如下圖所示:

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀
給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

主要流程是将文本詞向量表示,輸入到BiLSTM,選取正向和反向最後時刻的隐藏層輸出并進行拼接,然後經過一個relu輸出,輸入到softmax得到最後的機率。

3.2 蒸餾目标

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

其中w_{i} 是權重矩陣 W 的第i行, z 等于w^Th

蒸餾的目标就是為了最小化student模型與teacher模型的平方誤差MSE:

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

其中Z(B) 和Z(S) 分類代表teacher和student模型的logit輸出

最終蒸餾模型的訓練函數可以将MSE損失和交叉熵損失結合起來:

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

3.3 資料增強

  • 用[MASK]随機替換單詞:“I loved the comedy.”變成“I [MASK] the comedy”
  • 基于POS标簽替換單詞;“What do pigs eat?” 變成“How do pigs eat?”
  • 從樣本中随機取出n-gram作為新的樣本

4 實驗結果

本文采用的資料集為SST-2、MNLI、QQP

實驗結果如下:

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

推理更加快:

給Bert加速吧!NLP中的知識蒸餾論文 Distilled BiLSTM解讀

5 蒸餾代碼

https://github.com/qiangsiwei/bert_distill

# coding:utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from keras.preprocessing import sequence
import pickle
from tqdm import tqdm
import numpy as np
from transformers import BertTokenizer
from utils import load_data
from bert_finetune import BertClassification


USE_CUDA = torch.cuda.is_available()
if USE_CUDA: torch.cuda.set_device(0)
FTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
device = torch.device('cuda' if USE_CUDA else 'cpu')

class RNN(nn.Module):
    def __init__(self, x_dim, e_dim, h_dim, o_dim):
        super(RNN, self).__init__()
        self.h_dim = h_dim
        self.dropout = nn.Dropout(0.2)
        self.emb = nn.Embedding(x_dim, e_dim, padding_idx=0)
        self.lstm = nn.LSTM(e_dim, h_dim, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(h_dim * 2, o_dim)
        self.softmax = nn.Softmax(dim=1)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        embed = self.dropout(self.emb(x))
        out, _ = self.lstm(embed)
        hidden = self.fc(out[:, -1, :])
        return self.softmax(hidden), self.log_softmax(hidden)


class Teacher(object):
    def __init__(self, bert_model='bert-base-chinese', max_seq=128, model_dir=None):
        self.max_seq = max_seq
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
        self.model = torch.load(model_dir)
        self.model.eval()

    def predict(self, text):
        tokens = self.tokenizer.tokenize(text)[:self.max_seq]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        padding = [0] * (self.max_seq - len(input_ids))
        input_ids = torch.tensor([input_ids + padding], dtype=torch.long).to(device)
        input_mask = torch.tensor([input_mask + padding], dtype=torch.long).to(device)
        logits = self.model(input_ids, input_mask, None)
        return F.softmax(logits, dim=1).detach().cpu().numpy()


def train_student(bert_model_dir="/data0/sina_up/dajun1/src/doc_dssm/sentence_bert/bert_pytorch",
                  teacher_model_path="./model/teacher.pth",
                  student_model_path="./model/student.pth",
                  data_dir="data/hotel",
                  vocab_path="data/char.json",
                  max_len=50,
                  batch_size=64,
                  lr=0.002,
                  epochs=10,
                  alpha=0.5):

    teacher = Teacher(bert_model=bert_model_dir, model_dir=teacher_model_path)
    teach_on_dev = True
    (x_tr, y_tr, t_tr), (x_de, y_de, t_de), vocab_size = load_data(data_dir, vocab_path)

    l_tr = list(map(lambda x: min(len(x), max_len), x_tr))
    l_de = list(map(lambda x: min(len(x), max_len), x_de))

    x_tr = sequence.pad_sequences(x_tr, maxlen=max_len)
    x_de = sequence.pad_sequences(x_de, maxlen=max_len)

    with torch.no_grad():
        t_tr = np.vstack([teacher.predict(text) for text in t_tr])
        t_de = np.vstack([teacher.predict(text) for text in t_de])

    with open(data_dir+'/t_tr', 'wb') as fout: pickle.dump(t_tr,fout)
    with open(data_dir+'/t_de', 'wb') as fout: pickle.dump(t_de,fout)

    model = RNN(vocab_size, 256, 256, 2)

    if USE_CUDA: model = model.cuda()
    opt = optim.Adam(model.parameters(), lr=lr)
    ce_loss = nn.NLLLoss()
    mse_loss = nn.MSELoss()
    for epoch in range(epochs):
        losses, accuracy = [], []
        model.train()
        for i in range(0, len(x_tr), batch_size):
            model.zero_grad()
            bx = Variable(LTensor(x_tr[i:i + batch_size]))
            by = Variable(LTensor(y_tr[i:i + batch_size]))
            bl = Variable(LTensor(l_tr[i:i + batch_size]))
            bt = Variable(FTensor(t_tr[i:i + batch_size]))
            py1, py2 = model(bx)
            loss = alpha * ce_loss(py2, by) + (1-alpha) * mse_loss(py1, bt)  # in paper, only mse is used
            loss.backward()
            opt.step()
            losses.append(loss.item())
        for i in range(0, len(x_de), batch_size):
            model.zero_grad()
            bx = Variable(LTensor(x_de[i:i + batch_size]))
            bl = Variable(LTensor(l_de[i:i + batch_size]))
            bt = Variable(FTensor(t_de[i:i + batch_size]))
            py1, py2 = model(bx)
            loss = mse_loss(py1, bt)
            if teach_on_dev:
                loss.backward()             
                opt.step()
            losses.append(loss.item())
        model.eval()
        with torch.no_grad():
            for i in range(0, len(x_de), batch_size):
                bx = Variable(LTensor(x_de[i:i + batch_size]))
                by = Variable(LTensor(y_de[i:i + batch_size]))
                bl = Variable(LTensor(l_de[i:i + batch_size]))
                _, py = torch.max(model(bx, bl)[1], 1)
                accuracy.append((py == by).float().mean().item())
        print(np.mean(losses), np.mean(accuracy))
    torch.save(model, student_model_path)


if __name__ == "__main__":
    train_student()            

複制

參考連結

  • 【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作
  • 【論文筆記】Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
  • 知識蒸餾論文選讀(二)