天天看點

完全解析!Bert & Transformer 閱讀了解源碼詳解

接上一篇: 你所不知道的 Transformer! 超詳細的 Bert 文本分類源碼解讀 | 附源碼 中文情感分類單标簽 參考論文: https://arxiv.org/abs/1706.03762 https://arxiv.org/abs/1810.04805

在本文中,我将以run_squad.py以及SQuAD資料集為例介紹閱讀了解的源碼,官方代碼基于tensorflow-gpu 1.x,若為tensorflow 2.x版本,會有各種錯誤,建議切換版本至1.14。 

當然,注釋好的源代碼在這裡:

https://github.com/sherlcok314159/ML/tree/main/nlp/code 章節

  • Demo傳參
  • 資料篇
    • 番外句子分類
    • 創造執行個體
    • 執行個體轉換
  • 模型構造
  • 寫入預測

python bert/run_squad.py \
  --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  --do_train=True \
  --train_file=SQUAD_DIR/train-v2.0.json \
  --train_batch_size=8 \
  --learning_rate=3e-5 \
  --num_train_epochs=1.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad2.0_base/ \
  --version_2_with_negative=True      

閱讀源碼最重要的一點不是拿到就讀,而是跑通源碼裡面的小demo,因為你跑通demo就意味着你對代碼的一些基礎邏輯和參數有了一定的了解。

前面的參數都十分正常,如果不懂,建議看我的文本分類的講解。這裡講一下比較特殊的最後一個參數,我們做的任務是閱讀了解,如果有答案缺失,在SQuAD1.0是不可以的,但是在SQuAD允許,這也就是True的意思。

需要注意,不同人的檔案路徑都是不一樣的,你不能照搬我的,要改成自己的路徑。

其實閱讀了解任務模型是跟文本分類幾乎是一樣的,大的差異在于兩者對于資料的處理,是以本篇文章重點在于如何将原生的資料轉換為閱讀了解任務所能接受的資料,至于模型構造篇,請看文本分類:

https://github.com/sherlcok314159/ML/blob/main/nlp/tasks/text.md

想必很多人看到SquadExample類的_repr_方法都很疑惑,這裡處理好一個example,為什麼後面還要進行處理?看英文注釋會發現這個類其實跟閱讀了解沒關系,它隻是處理之後對于句子分類任務的,自然在run_squad.py裡面沒被調用。_repr_方法隻是在有start_position的時候進行字元串的拼接。

完全解析!Bert & Transformer 閱讀了解源碼詳解

用于訓練的資料集是json檔案,需要用json庫讀入。

訓練集的樣式如下,可見data是最外層的

{
    "data": [
        {
            "title": "University_of_Notre_Dame",
            "paragraphs": [
                {
                    "context": "Architecturally, the school has a Catholic character.",
                    "qas": [
                        {
                            "answers": [
                                {
                                    "answer_start": 515,
                                    "text": "Saint Bernadette Soubirous"
                                }
                            ],
                            "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?",
                            "id": "5733be284776f41900661182"
                        }
                    ]
                }
            ]
        },
        {
            "title":"...",
            "paragraphs":[
                {
                    "context":"...",
                    "qas":[
                        {
                            "answers":[
                                {
                                    "answer_start":..,
                                    "text":"...",
                                }
                            ],
                            "question":"...",
                            "id":"..."
                        },
                    ]
                }
            ]
        }
    ]
}      
完全解析!Bert & Transformer 閱讀了解源碼詳解

input_data是一個大清單,然後每一個元素樣式如下

{'paragraphs': [{...}, {...}, {...}, {...}, {...}, {...}, {...}, {...}, {...}, ...], 'title': 'University_of_Notre_Dame'}      

is_whitespace方法是用來判斷是否是一個空格,在切分字元然後加入doc_tokens會用到。

完全解析!Bert & Transformer 閱讀了解源碼詳解

然後我們層層剝開,然後周遊context的内容,它是一個字元串,是以周遊的時候會周遊每一個字母,字元會被進行判斷,如果是空格,則加入doc_tokens,char_to_word_offset表示切分後的索引清單,每一個元素表示一個詞有幾個字元組成。

完全解析!Bert & Transformer 閱讀了解源碼詳解

切分後的doc_tokens會去掉空白部分,同時會包括英文逗号。一個單詞會有很多字元,每個字元對應的索引會存在char_to_word_offset,例如,前面都是0,代表這些字元都是第一個單詞的,是以都是0,換句話說就是第一個單詞很長。

doc_tokens = ['Architecturally,', 'the', 'school', 'has', 'a', 'Catholic', 'character.', 'Atop', 'the',"..."]

char_to_word_offset = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]      

接下來進行qas内容的周遊,每個元素稱為qa,進行id和question内容的配置設定,後面都是初始化一些參數

完全解析!Bert & Transformer 閱讀了解源碼詳解

qa裡面還有一個is_impossible,用于判斷是否有答案

完全解析!Bert & Transformer 閱讀了解源碼詳解

確定有答案之後,剛剛讀入了問題,現在讀入與答案相關的部分,讀入的時候注意start_position和end_position是相對于doc_tokens的

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來對答案部分進行雙重檢驗,actual_text是根據doc_tokens和始末位置拼接好的内容,然後對orig_answer_text進行空格切分,最後用find方法判斷orig_answer_text是否被包含在actual_text裡面。

完全解析!Bert & Transformer 閱讀了解源碼詳解

這個是針對is_impossible來說的,如果沒有答案,則把始末位置全部變成-1。

完全解析!Bert & Transformer 閱讀了解源碼詳解

然後将example變成SquadExample的執行個體化對象,将example加入大清單——examples并傳回,至此執行個體建立完成。

完全解析!Bert & Transformer 閱讀了解源碼詳解

把json檔案變成執行個體之後,我們還差一步便可以把資料塞進模型進行訓練了,那就是将執行個體轉化為變量。

先對question_text進行簡單的空格切分變為query_tokens

完全解析!Bert & Transformer 閱讀了解源碼詳解

如果問題過長,就進行截斷操作

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來對doc_tokens進行空格切分以及詞切分,變成all_doc_tokens,需要注意的是orig_to_tok_index代表的是doc_tokens在all_doc_tokens的索引,取最近的一個,而tok_to_orig_index代表的是all_doc_tokens在doc_tokens索引

完全解析!Bert & Transformer 閱讀了解源碼詳解

對tok_start_position和tok_end_position進行初始化,記住,這兩個是相對于all_doc_tokens來說的,一定要與start_position和end_position區分開來,它們是相對于doc_tokens來說的

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來先介紹_improve_answer_span方法,這個方法是用來處理特殊的情況的,舉個例子,假如說你的文本是"The Japanese electronics industry is the lagest in the world.",你的問題是"What country is the top exporter of electornics?" 那答案其實應該是Japan,可是呢,你用空格和詞切分的時候會發現Japanese已經在詞表中可查,這意味着不會對它進行再切分,會直接将它傳回,這種情況下可能需要這個方法救場。

完全解析!Bert & Transformer 閱讀了解源碼詳解

因為是監督學習,答案已經給出,是以呢,這個方法幹的事情就是詞切分後的tokens進行再一次切分,如果發現切分之後會有更好的答案,就傳回新的始末點,否則就傳回原來的。

對tok_start_position和tok_end_position進行進一步指派

完全解析!Bert & Transformer 閱讀了解源碼詳解

計算max_tokens_for_doc,與文本分類類似,需要減去[CLS]和兩個[SEP]的位置,這裡不同的是還要減去問題的長度,因為這裡算的是文本的長度。 

tokens = [CLS] query tokens [SEP] context [SEP]

完全解析!Bert & Transformer 閱讀了解源碼詳解

很多時候文章長度大于maximum_sequence_length的時候,這個時候我們要對文章進行切片處理,把它按照一定長度進行切分,每一個切片稱為一個doc_span,start代表從哪開始,length代表一個的長度。

完全解析!Bert & Transformer 閱讀了解源碼詳解

doc_spans儲存很多個doc_span。這裡對視窗的長度有所限制,規定了start_offset不能比doc_stride大,這是第二個視窗的起點,從這個角度或許可以了解doc_stride代表平滑的長度。

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來的操作跟文本分類有些類似,添加[CLS],然後添加問題和[SEP],這些在segment_ids裡面都為0。

完全解析!Bert & Transformer 閱讀了解源碼詳解

下面講_check_is_max_context方法,這個方法是用來判斷某個詞是否具有完備的上下文關系,源代碼給了一個例子: 

Span A: the man went to the 

Span B: to the store and bought 

Span C: and bought a gallon of ... 

那麼對于bought來說,它在Span B和Span C中都有出現,那麼,哪一個上下文關系最全呢?其實我們憑直覺應該可以猜到應該是Span C,因為Span B中bought出現在句末,沒有下文。當然了,我們還是得用公式計算一下

score = min(num_left_context, num_right_context) + 0.01 * doc_span.length      

score_B = min(4, 0) + 0.05 = 0.05 

score_C = min(1,3) + 0.05 = 1.05 

是以,在Span C中,bought的上下文語義最全,最終該方法會傳回True or False,在滑動視窗這個方法中,一個詞很可能出現在多個span裡面,是以用這個方法判斷目前這個詞在目前span裡面是否具有最完整的上下文

完全解析!Bert & Transformer 閱讀了解源碼詳解

回到上面,token_to_orig_map是用來記錄文章部分在all_doc_tokens的索引,而token_is_max_context是記錄文章每一個詞在目前span裡面是否具有最完整的上下文關系,因為一開始隻有一個span,那麼一開始每個詞肯定都是True。split_token_index用于切分成每一個token,這樣可以進行上下文關系判斷,至于後面添[SEP]和segment_ids添1這種操作文本分類也有。

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來将tokens(精細化切分後的)按照詞表轉化為id,另外若不足,則把0填充進去這種操作也是很常見的。

完全解析!Bert & Transformer 閱讀了解源碼詳解

前面是進行判斷,如果切了之後答案并不在span裡面就直接舍棄,若在裡面,因為一開始all_doc_tokens裡面沒有問題和[CLS],[SEP]時正文的索引是tok_start_position,然後轉換為input_ids又有問題以及[CLS],[SEP],是以要得到正文索引需要跳過它們。

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來大量的tf.logging隻是寫入日志資訊,同時也是你終端或輸出那裡看到的。

最終用這些參數執行個體化InputFeatures對象,然後不斷重複,每一個feature對應着一個特殊的id,即為unique_id。

模型建構

這裡大緻與文本分類差不多,隻是文本分類在模型裡面直接進行了softmax處理,然後進行最小交叉熵損失,而這次我們沒有直接這樣做,得到了開頭和結尾處的未歸一化的機率logits,之後我們直接傳回。

然後這次我們是在model_fn_builder方法裡面的子方法model_fn裡定義compute_loss,其實這裡也是經過softmax進行歸一化,然後再計算交叉熵損失,最終傳回均方誤差。

完全解析!Bert & Transformer 閱讀了解源碼詳解

然後我們計算開頭和結尾處的損失,總損失為二者和的平均。

最終我們進行優化。

完全解析!Bert & Transformer 閱讀了解源碼詳解

寫入預測 

start_logit & end_logit 代表着未經過softmax的機率,start_logit表示tokens裡面以每一個token作為開頭的機率,後者類似的。還有一對null_start_logit & null_end_logit,它們兩個代表的是SQuAD2.0沒有答案的那些,預設全為0。

首先,簡單介紹一下_get_best_indexes,這個方法是用來輸出由高到低前n_best_size個的機率的索引。

完全解析!Bert & Transformer 閱讀了解源碼詳解

周遊start_indexes,end_indexes(都是分别經過_get_best_indexes得到),對于答案未缺失的,以具體的logit填入,另外,feature_index代表第幾個feature。

完全解析!Bert & Transformer 閱讀了解源碼詳解

如果答案缺失,則全都為0

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來我們進一步轉換為具體的文本

完全解析!Bert & Transformer 閱讀了解源碼詳解

然後進一步清洗資料

完全解析!Bert & Transformer 閱讀了解源碼詳解

這樣還有個問題,詞切分會自動小寫,與答案還存在一定的偏移,這裡介紹get_final_text方法來解決這一問題,比如: 

pred_text = steve smith 

orig_text = Steve Smith's 

這個方法通俗來講就是獲得orig_text(未經過詞切分)上正确的截取片段。 

然後将其添加到nbest中

完全解析!Bert & Transformer 閱讀了解源碼詳解

同樣會存在沒有答案的情況

完全解析!Bert & Transformer 閱讀了解源碼詳解

接下來會有一個total_scores,它的元素是start_logit和end_logit相加,注意,它們不是數值,是數組,之後就計算total_scores的交叉熵損失作為機率。

完全解析!Bert & Transformer 閱讀了解源碼詳解

剩下的部分跟文本分類差不多,這裡就此略過。