自然語言推理:微調BERT
Natural Language Inference: Fine-Tuning BERT
SNLI資料集上的自然語言推理任務設計了一個基于注意力的體系結構。現在通過微調BERT來重新讨論這個任務。自然語言推理是一個序列級文本對分類問題,而微調BERT隻需要額外的基于MLP的架構,如圖1所示。

Fig. 1. This section feeds pretrained BERT to an MLP-based architecture for natural language inference.
下載下傳一個經過預訓練的小版本BERT,然後對其進行微調,以便在SNLI資料集上進行自然語言推理。
from d2l import mxnet as d2l
import json
import multiprocessing
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
import os
npx.set_np()
1. Loading Pretrained BERT
解釋了如何在WikiText-2資料集上預訓練BERT(注意,原始的BERT模型是在更大的語料庫上預訓練的)。最初的BERT模型有上億個參數。提供兩個版本的預訓練BERT:“bert.base “大約和原始的BERT基模型一樣大,需要大量的計算資源進行微調,而“bert.small”是一個小版本,便于示範。
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
'7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
'a4e718a47137ccd1809c9107ab4f5edd317bae2c')
任何一個預訓練的BERT模型都包含一個“vocab.json”定義詞彙集和“pretrained.params”預訓練參數的檔案。實作了如下加載預訓練模型函數來load_pretrained_model加載預訓練的BERT參數。
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, max_len, ctx):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab([])
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
num_layers, dropout, max_len)
# Load pretrained BERT parameters
bert.load_parameters(os.path.join(data_dir, 'pretrained.params'), ctx=ctx)
return bert, vocab
為了便于在大多數機器上示範,将加載并微調小版本(“bert.small”)的名稱。在練習中,将示範如何微調更大的“bert.base”以顯著提高測試精度。
ctx = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
num_layers=2, dropout=0.1, max_len=512, ctx=ctx)
Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...
2. The Dataset for Fine-Tuning BERT
對于SNLI資料集上的下遊任務自然語言推理,定義了一個自定義的資料集類SNLIBERTDataset。在每個例子中,前提和假設形成一對文本序列,并被打包成一個BERT輸入序列,如圖2所示。段IDs用于區分BERT輸入序列中的前提和假設。使用預定義的BERT輸入序列的最大長度(max_len),輸入文本對中較長的最後一個标記會一直被删除,直到滿足max_len。為了加速生成用于微調BERT的SNLI資料集,使用4個worker程序并行地生成訓練或測試示例。
class SNLIBERTDataset(gluon.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens] for p_tokens, h_tokens in zip(
*[d2l.tokenize([s.lower() for s in sentences])
for sentences in dataset[:2]])]
self.labels = np.array(dataset[2])
self.vocab = vocab
self.max_len = max_len
(self.all_token_ids, self.all_segments,
self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
print('read ' + str(len(self.all_token_ids)) + ' examples')
def _preprocess(self, all_premise_hypothesis_tokens):
pool = multiprocessing.Pool(4) # Use 4 worker processes
out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
all_token_ids = [
token_ids for token_ids, segments, valid_len in out]
all_segments = [segments for token_ids, segments, valid_len in out]
valid_lens = [valid_len for token_ids, segments, valid_len in out]
return (np.array(all_token_ids, dtype='int32'),
np.array(all_segments, dtype='int32'),
np.array(valid_lens))
def _mp_worker(self, premise_hypothesis_tokens):
p_tokens, h_tokens = premise_hypothesis_tokens
self._truncate_pair_of_tokens(p_tokens, h_tokens)
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
* (self.max_len - len(tokens))
segments = segments + [0] * (self.max_len - len(segments))
valid_len = len(tokens)
return token_ids, segments, valid_len
def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
# Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the BERT
# input
while len(p_tokens) + len(h_tokens) > self.max_len - 3:
if len(p_tokens) > len(h_tokens):
p_tokens.pop()
else:
h_tokens.pop()
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx]), self.labels[idx]
def __len__(self):
return len(self.all_token_ids)
在下載下傳SNLI資料集之後,通過執行個體化SNLIBERTDataset類來生成訓練和測試示例。這些例子将在自然語言推理的訓練和測試中分批閱讀。
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
num_workers=num_workers)
read 549367 examples
read 9824 examples
3. Fine-Tuning BERT
如圖2所示,用于自然語言推理的微調BERT隻需要由兩個完全連接配接的層組成的額外MLP(參見自隐藏以及自輸出在下面的BERTClassifier類中)。這種MLP将編碼前提和假設資訊的特殊“<cls>”标記的BERT表示轉化為自然語言推理的三種輸出:蘊涵、沖突和中性。
class BERTClassifier(nn.Block):
def __init__(self, bert):
super(BERTClassifier, self).__init__()
self.encoder = bert.encoder
self.hidden = bert.hidden
self.output = nn.Dense(3)
def forward(self, inputs):
tokens_X, segments_X, valid_lens_x = inputs
encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
return self.output(self.hidden(encoded_X[:, 0, :]))
接下來,将預訓練的BERT模型BERT輸入BERT分類器執行個體網絡,供下遊應用程式使用。在一般的BERT微調實作中,隻有輸出層的參數附加MLP(net.output)從零開始學習。預訓練BERT編碼器的所有參數(net.encoder)以及附加MLP的隐藏層(net.hidden)将進行微調。
net = BERTClassifier(bert)
net.output.initialize(ctx=ctx)
MaskLM類和NextSentencePred類在使用的mlp中都有參數。這些參數是預訓練BERT模型BERT的一部分,是以也是網絡中的一部分。然而,這些參數僅用于計算預訓練過程中的隐含語言模組化損失和下一句預測損失。這兩個損失函數與下遊應用的微調無關,是以當對BERT進行微調時,MaskLM和NextSentencePred中使用的MLPs的參數不會更新(過期)。
為了允許參數具有過時漸變,在d2l.train_batch_ch13的步進函數中設定标志ignore_stale_grad=True。利用SNLI的訓練集(train_iter)和測試集(test_iter)來訓練和評估模型網絡。由于計算資源有限,訓練和測試的準确性還有待進一步提高:将其讨論留在練習中。
lr, num_epochs = 1e-4, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx,
d2l.split_batch_multi_inputs)
loss 0.597, train acc 0.741, test acc 0.713
- We can fine-tune the pretrained BERT model for downstream applications, such as natural language inference on the SNLI dataset.
- During fine-tuning, the BERT model becomes part of the model for the downstream application. Parameters that are only related to pretraining loss will not be updated during fine-tuning.
人工智能晶片與自動駕駛