天天看點

OCR性能優化:從認識BiLSTM網絡結構開始

摘要: 想要對OCR進行性能優化,首先要了解清楚待優化的OCR網絡的結構,本文從動機的角度來推演下基于Seq2Seq結構的OCR網絡是如何一步步搭建起來的。

本文分享自華為雲社群《OCR性能優化系列(一):BiLSTM網絡結構概覽》,原文作者:HW007。

OCR是指對圖檔中的印刷體文字進行識别,最近在做OCR模型的性能優化,用 Cuda C 将基于TensorFlow 編寫的OCR網絡重寫了一遍,最終做到了5倍的性能提升。通過這次優化工作對OCR網絡的通用網絡結構和相關的優化方法有較深的認識,計劃在此通過系列博文記錄下來,也作為對自己最近工作的一個總結和學習筆記。

想要對OCR進行性能優化,首先要了解清楚待優化的OCR網絡的結構,在本文中我将嘗試着從動機的角度來推演下基于Seq2Seq結構的OCR網絡是如何一步步搭建起來的。

讀懂此文的前提隻需要了解在矩陣乘法中矩陣的次元變化規律,即n*p的矩陣乘以 p*m 的矩陣等于 n*m 的矩陣。如果知道CNN和RNN網絡的結構,對機器學習模型的構造套路有點了解的話更好。

首先給出從本文要剖析的OCR BILSTM 網絡總體結構如下圖:

OCR性能優化:從認識BiLSTM網絡結構開始

接下來我将從這張圖的右上角(模型的輸出端)向左下角(模型的輸入端)逐漸解釋每一個結構的動機及其作用。

1. 構造最簡單的OCR網絡

首先考慮最簡單情況下的OCR識别場景,假設輸入是隻含有一個文字圖檔,圖檔的高和寬均為32個像素,即32*32的矩陣,為了友善将其拉長便可得到一個 1*1024 的矩陣。在輸出方面,由于文字的特殊性,我們隻能将所有的文字進行标号,最後輸出所識别的文字的編号便好,由此得到我們的輸出是一個 1*1 的矩陣,矩陣元素的内容就是所識别的文字的編号。

怎麼得到這個1*1的矩陣呢?根據機率統計的套路,我們假設全世界存在10000個文字,将其表為1~1000号,那麼這10000個元素都有機率成為我們的輸出,是以我們如果先算出這10000個文字作為該輸入圖檔的識别結果的機率的話,再挑機率最大的那個輸出便可以了。于是問題被轉變成如何從一個 1*1024的矩陣(X)中得到一個 1*10000 的矩陣(Y)。在這裡便可以上機器學習模型結構中最常見的線性假設套路了,假設Y和X是之間是線性相關的,這樣便可得到最簡單且經典的線性模型:Y = AX + B。 其中稱X(次元:1*1024)為輸入,Y(次元:1*10000)為輸出,A和B均為該模型的參數,由矩陣乘法可知A的次元應該是 1024*1000,B的次元應該是 1*10000。至此,隻有X是已知的,我們要計算Y的話還需要知道A和B的具體值。在機器學習的套路中,作為參數的A和B的值在一開始是随機設定的,然後通過喂大量的X及其标準答案Y來讓機器把這兩個參數A、B慢慢地調整到最優值,此過程稱為模型的訓練,喂進去的資料稱為訓練資料。訓練完後,你便可以拿最優的A乘以你的新輸入X在加上最優的B得到相應的Y了,使用argMax操作來挑選Y這1*10000個數中最大的那個數的編号,就是識别出來的文字的編号了。

現在,再回頭去看圖1中右上角的那部分,相信你能看懂兩個黃色的 384*10000 和 1*10000的矩陣的含義了。圖中例子和上段文字描述的例子的差別主要在于圖中的輸入是1張 1*1024的圖檔,上段文字中的是 27張 1*384的圖檔罷了。至此,你已經了解如何構造一個簡單地OCR網絡了。接下來我們就開始對這個簡單地網絡進行優化。

2. 優化政策一:減少計算量

在上面的文字描述的例子中,我們每識别一個文字就要做一次 1*1024和1024*10000的矩陣乘法計算,這裡面計算量太大了,是否有一些計算是備援的呢?熟悉PCA的人應該馬上能想到,其實将 32*32 的文字圖檔拉長為 1*1024的矩陣,這個文字的特征空間是1024維,即便每維的取值隻有0和1兩種,這個特征空間可表示的值都有2^1024種,遠遠大于我們所假設的文字空間中所有文字個數10000個。為此我們可以用PCA或各種降維操作把這個輸入的特征向量降維到小于10000維,比如像圖中的128維。

3. 優化政策二:考慮文字間的相關性

(提醒:在上圖中為了展現出batch Size的次元,是按27張文字圖檔來畫的,下文中的讨論均隻針對1張文字圖檔,是以下文中次元為 1的地方均對應着圖中的27)

也許你已經注意到了,圖中與黃色的384*10000矩陣相乘的“位置圖像特征”的次元沒有直接用一個1*384,而是 1*(128+128+128)。其實這裡隐含着一個優化,這個優化是基于文字間的關聯假設的,簡單地例子就是如果前面一個字是“您”,那其後面跟着的很可能是“好”字,這種文字順序中的統計規律應該是可以用來提升文字圖檔的識别準确率的。那怎麼來實作這個關聯呢?

在圖中我們可以看到左側有一個10000*128的參數矩陣,很容易知道這個參數就像一個資料庫,其儲存了所有10000個文字圖檔經過加工後的特征(所謂加工便是上面提到的降維,原始特征應該是 10000*1024的),照圖中的結構,我需要輸入目前識别的這個字的前一個字的識别結果 (識别工作是一個字接一個字串行地識别出來的)。然後選擇出上個字對應的特征矩陣 1*128,再經過一些加工轉換後當做1*384的輸入中的前1/3部分内容。

同理,1*384裡靠後的兩個1*128又代表什麼含義呢?雖然在句子中,前面一個字對後面一個字的影響很大,即使目前要預測的字在圖檔中很模糊,我也可以根據前面的字将其猜出來。那是否可以根據其前k個字或者後k個字猜出來呢?顯然答案是肯定的。是以靠後的兩個1*128分别代表的是句子圖檔裡文字“從前到後(Forward)”和“從後到前(Backward)”的圖檔特征對目前要識别的字的影響,是以圖中在前面加了個“雙向LSTM網絡”來生成這兩個特征。

至此,改良版的OCR網絡輪廓基本出來了,還有一些細節上的問題需要解決。不知你是否注意到,按上面所述,1*384中包含了3個1*128的特征,分别代表着前一個字對目前字的影響、圖檔中的整個句子中各個文字從前到後(Forward)的排序對目前文字的影響、圖檔中的整個句子中各個文字從後到前(Backward)的排序對目前文字的影響。

但是他們的特征長度都是128!!!一個字是128,一個句子也是128?對于不同的文字圖檔中,句子的長度還可能不一樣,怎麼可能都用一個字的特征長度就表示了呢?

如何表示一個可變長的句子的特征呢?乍一看的确是個很棘手的問題,好在它有一個很粗暴簡單的解決辦法,就是權重求和,又是機率統計裡面的套路,管你有幾種情況,所有的情況的機率求和後都得等于1。看到在這裡不知道是否被震撼到,“變化”和“不變”這樣看起來水火不容的兩個東西就是這麼神奇地共存了,這就是數學的魅力,讓人不禁拍手贊絕!

下圖以一個實際的例子說明這種神奇的方式的運作方式。當我們要對文字片段中的“筷”字進行識别時,盡管改字已近被遮擋了部分,但根據日常生活中的一些經驗知識積累,要對該位置進行補全填空時,我們聯系上下文,把注意力放在上文中的“是中國人”和下文中的“吃飯”上。這個權重系數的機制便是用來實作這種注意力機制的。至于“日常生活中的經驗”這種東西就是由“注意力機制網絡”通過大量的訓練資料來學習得到的。也就是圖1中的那32個alpha的由來。注意力網絡在業界一般由GRU網絡擔任,由于篇幅原因,在此不展開了,下回有機會再細說。看官們隻需知道在圖一的右邊還應該有個“注意力網絡”來輸出32個alpha的值便好。

OCR性能優化:從認識BiLSTM網絡結構開始

點選關注,第一時間了解華為雲新鮮技術~

繼續閱讀