天天看點

超詳細的 Bert 文本分類源碼解讀 | 附源碼

本文詳細的GitHub位址:

https://github.com/sherlcok314159/ML

接上一篇:

你所不知道的 Transformer!

參考論文

https://arxiv.org/abs/1706.03762 https://arxiv.org/abs/1810.04805

在本文中,我将以run_classifier.py以及MRPC資料集為例介紹關于bert以及transformer的源碼,官方代碼基于tensorflow-gpu 1.x,若為tensorflow 2.x版本,會有各種錯誤,建議切換版本至1.14。

當然,注釋好的源代碼在這裡:

https://github.com/sherlcok314159/ML/tree/main/nlp/code

章節

  • Demo傳參
    • 跑不動?
  • 資料篇
    • 資料讀入
    • 資料處理
  • 詞處理
    • 切分
    • 詞向量編碼
  • TFRecord檔案建構
  • 模型建構
    • 詞向量拼接
      • 句子類型編碼
      • 位置編碼
    • 多頭注意力
      • MASK機制
      • Q,K,V矩陣建構
    • 損失優化
    • 構模組化型
    • 其他注意點

首先大家拿到這個模型,管他什麼原理,肯定想跑起來看看結果,至于預訓練模型以及資料集下載下傳。任何時候應該先看官方教程:

https://github.com/google-research/bert

官方代表着權威,更容易實作,如果遇到問題可以去issues和stackoverflow看看,再輔以中文教程,一般上手就不難了,這裡就不再贅述了。

先從Flags參數講起,到如何跑通demo。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

拿到源碼不要慌張,英文注釋往往起着最關鍵的作用,另外閱讀源碼詳細技巧可以看源碼技巧:

https://github.com/sherlcok314159/ML/blob/main/nlp/source_code.md

"Required Parameters"意思是必要參數,你等會執行時必須向程式裡面傳的參數。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name=MRPC \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/      

這是官方給的示例,這個将兩個檔案夾加入了系統路徑,本人Ubuntu18.04加了好像也找不到,是以建議将那些檔案路徑改為絕對路徑。

task_name --> 這次任務的名稱
do_train --> 是否做fine-tune
do_eval --> 是否交叉驗證
do_predict --> 是否做預測
data_dir --> 資料集的位置
vocab_dir --> 詞表的位置(一般bert模型下好就能找到) 
bert_config --> bert模型參數設定
init_checkpoint --> 預訓練好的模型
max_seq_length --> 一個序列的最大長度
output_dir --> 結果輸出檔案(包括日志檔案)
do_lower_case --> 是否小寫處理(針對英文)

其他的字面意思      

跑不動? 

有些時候發現跑demo的時候會出現各種問題,這裡簡單彙總一下 

1. No such file or directory! 這個意思是沒找到,你需要確定你上面模型和資料檔案的路徑填正确就可解決

2. Memory Limit

超詳細的 Bert 文本分類源碼解讀 | 附源碼

因為bert參數量巨大,模型複雜,如果GPU顯存不夠是帶不動的,就會出現上圖的情形不斷跳出。

解決方法

  • 把batch_size,max_seq_length,num_epochs改小一點
  • 把do_train直接false掉
  • 使用優化bert模型,如Albert,FastTransformer

經過本人實證,把參數适當改小參數,如果還是不行直接不做fine-tune就好,這對迅速跑通demo的人來說最有效。

這是很多時候我們自己跑别的任務最為重要的一章,因為很多時候模型并不需要你大改,人家都已經給你訓練好了,你在它的基礎上進行優化就好了。而資料如何讀入以及進行處理,讓模型可以訓練是至關重要的一步。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

簡單介紹一下我們的資料,第一列為Quality,意思是前後兩個句子能不能比對得起來,如果可以即為1,反之為0。第二,三兩列為ID,沒什麼意義,最後兩列分别代表兩個句子。

接下來我們看到DataProcessor類,(有些類的作用僅僅是初始化參數,本文不作講解)。這個類是父類(超類),後面不同任務資料處理類都會繼承自它。它裡面定義了一個讀取tsv檔案的方法。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

首先會将每一列的内容讀取到一個清單裡面,然後将每一行的内容作為一個小清單作為元素加到大清單裡面。

因為我們的資料集為MRPC,我們直接跳到MrpcProcessor類就好,它是繼承自DataProcessor。

這裡簡要介紹一下os.path.join。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

我們不是一共有三個資料集,train,dev以及test嘛,data_dir我們給的是它們的父目錄,我們如何能讀取到它們呢?以train為例,是不是得"path/train.tsv",這個時候,os.path.join就可以把兩者拼接起來。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這個意思是任務的标簽,我們的任務是二分類,自然為0&1。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

examples最終是清單,第一個元素為清單,内容圖中已有。

讀取資料之後,接下來我們需要對詞進行切分以及簡單的編碼處理

超詳細的 Bert 文本分類源碼解讀 | 附源碼

label_list前面對資料進行處理的類裡有get_labels參數,傳回的是一個清單,如["0","1"]。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

想要切分資料,首先得讀取詞表吧,代碼裡面一開始創造一個OrderedDict,這個是為什麼呢? 

在python 3.5的時候,當你想要周遊鍵值對的時候它是任意傳回的,換句話說它并不關心鍵值對的儲存順序,而隻是跟蹤鍵和值的關聯程度,會出現無序情況。而OrderedDict可以解決無序情況,它内部維護着一個根據插入順序排序的雙向連結清單,另外,對一個已經存在的鍵的重複複制不會改變鍵的順序。 

需要注意,OrderedDict的大小為一般字典的兩倍,尤其當儲存的東西大了起來的時候,需要慎重權衡。 

但是到了python 3.6,字典已經就變成有序的了,為什麼還用OrderedDict,我就有些疑惑了。如果說OrderedDict排序用得到,可是普通dict也能勝任,為什麼非要用OrderedDict呢?

超詳細的 Bert 文本分類源碼解讀 | 附源碼

在tokenization.py檔案中提供了三種切分,分别是BasicTokenizer,WordpieceTokenizer和FullTokenizer,下面具體介紹一下這三者。

在tokenization.py檔案中遍布convert_to_unicode,這是用來轉換為unicode編碼,一般來說,輸入輸出不會有變化。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這個方法是用來替換不合法字元以及多餘的空格,比如\t,\n會被替換為兩個标準空格。接下來會有一個_tokenize_chinese_chars方法,這個是對中文進行編碼,我們首先要判斷一下是否是中文字元吧,_is_chinese_char方法會進行一個判斷。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

如果是中文字元,_tokenize_chinese_chars會将中文字元旁邊都加上空格,圖中我也有引例注釋。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

whitespace_tokenize會進行按空格切分。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

_run_strip_accents會将變音字元替換掉,如résumé中的é會被替換為e。

接下來進行标點字元切分,前提是判斷是否是标點吧,_is_punctuation履行了這個職責,這裡不再多說。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

以上便是BasicTokenizer的内容了。

接下來是WordpieceTokenizer了,其實這個詞切分是針對英文單詞的,因為漢字每個字已經是最小的結構,不能進行切分了。而英文還可以進行切分,英文有不同語态,如loved,loves,loving等等,這個時候WordpieceTokenizer就能發揮作用了。

  • 周遊一個英文單詞裡面的小結構,如果發現在詞表裡找到,就把這個切掉
  • 對未被切分的部分繼續進行步驟一,直至所有都被切分幹淨,注意除了第一個,其他的前面都要加上"##"

下面有個gif可以直覺顯示,來源:

https://alanlee.fun/2019/10/16/bert-tokenizer/
超詳細的 Bert 文本分類源碼解讀 | 附源碼

最後是FullTokenizer,這個是兩者的內建版,先進行BasicTokenizer,後進行WordpieceTokenizer。當然了,對于中文,就沒必要跑WordpieceTokenizer。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

下面簡單提一下convert_by_vocab,這裡是将具體的内容轉換為索引。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

以上就是切分了。

剛剛對資料進行了切分,接下來我們跳到函數convert_single_example,進一步進行詞向量編碼。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這裡是初始化一個例子。input_ids 是等會把一個一個詞轉換為詞表的索引;segment_ids代表是前一句話(0)還是後一句話(1),因為這還未執行個體化,是以is_real_example為false。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

此處tokenizer.tokenize是FullTokenizer的方法。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

不同的任務可能含有的句子不一樣,上面代碼的意思就是若b不為空,那麼max_length = 總長度 - 3,原因注釋已有;若b為空,則就需要減去2即可。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

_truncate_seq_pair進行一個截斷操作,裡面用了pop(),這個是清單方法,把清單最後一個取出來,英文注釋也說了為什麼沒有按照比例截斷,若一個序列很短,那按比例截斷會流失資訊較多,因為比例是長短序列通用的。同時,_truncate_seq_pair還保證了a,b長度一緻。若b為空,a則不需要調用這個方法,直接清單方法取就好。

我們不是說需要在開頭添加[CLS],句子分割處和結尾添加[SEP]嘛(本次任務a,b均不為空),剛剛隻是進行了一個切分和截斷操作。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

tokens是我們用來放序列轉換為編碼的新清單,segment_ids用來差別是第一句還是第二句。這段代碼大意就是在開頭和結尾處加入[CLS],[SEP],因為是a是以都是第一句,segment_ids就都為0,同時[CLS]和[SEP]也都被當做是a的部分,編碼為0。下面關于b的同理。

接下來再把具體内容轉換為索引。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

我們一開始的參數不是有max_seq_length嘛,這個代表一整個序列的最大長度(a,b拼接的),但是很多時候我們的總序列長度不會達到最大長度,但是我們又要保證所有輸入序列長度一緻,即為最大序列長度。是以我們需要對剩下的部分,即沒有内容的部分進行填充(Padding),但填充的時候有個問題,一般我們都會添0,但做self-attention的時候(如果還不了解自注意力,可以去首頁看看我寫的Transformer的論文解讀),每一個詞要跟句子裡面所有的詞做内積,但是0是我們人為填充進去的,它不代表任何意義,然而,做自注意力的時候還是要跟它做内積,是不是不太合理呀?

于是就有了MASK機制,什麼意思呢?我們把機器需要看,需要做自注意力的保留,不要看的MASK掉,這樣做自注意力的時候就不會出岔子。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

同時,隻要沒達到最大長度,就全部補零。

這個的剩餘部分tf.logging是日志,不用管,這個convert_single_example最終傳回的是feature,feature包含什麼已經具體闡述過了。

因為用TFRecord讀取檔案比較友善快捷,需要轉換一下檔案格式。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

前半部分是examples寫入,examples是來自上圖方法。features是來自上面剛講過的convert_single_example方法。

需要注意的是這份run_classifier.py人家谷歌是用TPU跑的,是以會有TPU部分代碼,一般我們隻用GPU,是以TPU部分不需要關注,一般TPU都會出現TPUEstimator。

接下來,是構模組化型篇,是整個代碼中最重要的一部分。接下來我将用代碼介紹一下transformer模型的架構。

找到modeling.py檔案,這是模型檔案。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

首先是BertConfig的類,這裡自定義了一些參數及數值。

vocab_size --> 詞表的大小,用别人的詞表,這個參數已經固定
hidden_size --> 隐層神經元個數
num_hidden_layers --> encoder的層數
num_attention_heads -->注意力頭的個數
intermediate_size --> 中間層神經元個數
hidden_act --> 隐層激活函數
hidden_dropout_prob --> 在全連接配接層中實施Dropout,被去掉的機率
attention_probs_dropout_prob --> 注意力層dropout比例
max_position_embeddings --> 最大位置數目
initializer_range --> truncated_normal_initializer的stdev,用來初始化權重參數,從普通正态分布中标準差為0.02的分布中取樣出一部分參數,作為初始化權重      
超詳細的 Bert 文本分類源碼解讀 | 附源碼

詞向量拼接 

接下來正式進入Embedding層的操作,最終傳到注意力層的其實是原始token_ids,token_type_ids以及positional embedding拼接起來的。 

token_ids編碼 

首先是token_ids的操作,先來看一下embedding_lookup方法。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這是它的參數,大部分英文注釋已有,需要注意的一點是input_ids的shape必須為[batch_size,max_seq_length]。

接下來進行擴維。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

等會我們需要在embedding_table裡面查找,這裡先建構一個[vocab_size,embedding_size]的table。需要注意的是vocab_size 和 embedding_size 都是固定好的,訓練的時候不能亂改。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

之後我們對input_ids進行降維,貌似這樣可以加速。one_hot_embedding一般為false,這是對TPU加速用的。接下來在embedding_table裡面進行查找。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

然後我們把output reshape一下。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這就是token的編碼了。

進行位置編碼之前,我們首先進行對token_type_ids的編碼(判斷是哪一句)。

首先建立token_type_table。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

然後進行一個token_type_embedding,matul是矩陣相乘

超詳細的 Bert 文本分類源碼解讀 | 附源碼

做好相乘之後,我們需要把token_type_embedding的shape還原,因為等會要将token_type_ids與詞編碼相加。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

首先我們先創造大量的位置,max_position_embeddings是官方給定的參數,不能修改。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

我們創造了這麼多的位置,最終不一定用的完,為了更快速的訓練,我們一般做切片處理,隻要到我的max_seq_length還有位置就好,後面都可以不要。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

前面要把token_type_embeddings加到input_ids的編碼中,進行了同次元處理,這裡對于位置編碼也一樣,不然最後相加不了。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

至此,Embedding層就結束了。Transformer論文不是說了嘛,在加入位置編碼之前會進行一個Dropout操作

超詳細的 Bert 文本分類源碼解讀 | 附源碼

多頭機制

接下來來到整個transformer模型的精華部分,即為多頭注意力機制。

首先來到create_attention_mask_from_input_mask方法,from_seq_length和to_seq_length分别指的是a和b,前面講關于切分的時候已經說了,切分處理會讓a,b長度一緻為max_seq_length。是以這裡兩者長度相等。最後建立了一個shape為(batch_size,from_seq_length,to_seq_length)的MASK。又擴充了一個次元,那這個次元用來幹什麼呢?我們一開始不是說了嗎?自注意的時候需要将填充的部分遮掉,那麼多餘的次元幹的就是這個事。比如我們設定最大長度為8,句子長度為6,那麼有一個次元是[1,1,1,1,1,1,0,0]。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

Q,K,V矩陣

建構首先來到attention_layer方法,q,k,v矩陣的激活函數均為None。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

在進入建構之前,最好先熟悉這5個字母的含義。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

開始建構q矩陣,注意q是由from_tensor,即第一個句子建構的。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

接着建構k和v矩陣,都是從to_tensor建構的。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

接下來會對q,k矩陣進行加速内積處理,不做深入探讨。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

記得我們在transformer裡面需要除以d的次元開根号。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

attention_mask即為上節我們說的MASK,這裡進行拓展一個次元。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這裡再簡要介紹一下adder。tf.cast方法隻是轉換資料類型,這裡用x代表attention_mask,(1-x)* (-1000)的目的是當attention為1時,即要關注這個,那麼(1-x)就越趨近于0,那麼做softmax,值就越接近于0,類似地,如果attention為0,那麼進過softmax後的值就更接近-1。最後把這個adder加到剛剛我們得到注意力的值,估計這裡會有人搞不懂為什麼怎麼做。 

果關聯度很高,那麼attention_scores就越接近1,越低,越接近0,但是,很可能是我們補零的部分,是以我們需要對這個進行處理,這裡有兩種思路,既然是補零的,我們直接去掉就好;或者這裡谷歌的做法是如果不需要,直接-1,是不是注意力值就趨近于0了,如果需要,加了0本身值不會發生變化。經過谷歌驗證,後者效率更高。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

接下來進行transformer模型建構,不難發現這裡from_tensor和to_tensor一緻,是以是做自注意力。

在bert裡面說過,最後拿出開頭的[CLS]就可以了。這既是get_pooled_output方法的作用。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

最後再連接配接一個全連接配接層,最後就是二分類的任務w * x + b

超詳細的 Bert 文本分類源碼解讀 | 附源碼

model_fn方法是建構的函數之一,一定一定要小心,雖然上面寫着傳回給TPUEstimator,可如果你運作過demo的話,輸出的很多東西都來源于這個方法。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

進入main(_)主方法,需要注意的是,以後我們需要fine-tune,需要把我們自己定義的processor添加進processors。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

确認要訓練之後,會計算需要一共多少步完成,這裡還有個warm-up,意思是一開始呢讓learning rate低一下,等到了warm-up proportion之後再還原。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

終于我們開始構模組化型了

超詳細的 Bert 文本分類源碼解讀 | 附源碼

最終我們建構了estimator用于後期訓練,評估和預測

超詳細的 Bert 文本分類源碼解讀 | 附源碼

這是殘差相連的部分

超詳細的 Bert 文本分類源碼解讀 | 附源碼

還有一點就是記得在transformer中講過我們會連兩層全連接配接層,一層升維,另一層降維。

超詳細的 Bert 文本分類源碼解讀 | 附源碼

接下來進行降維

超詳細的 Bert 文本分類源碼解讀 | 附源碼

覺得寫的好,不妨去github上給我star,裡面有很多比這還要棒的解析: