天天看點

10 行代碼,實作手寫數字識别

識别手寫的阿拉伯數字,對于人類來說十分簡單,但是對于程式來說還是有些複雜的。

10 行代碼,實作手寫數字識别

不過随着機器學習技術的普及,使用10幾行代碼,實作一個能夠識别手寫數字的程式,并不是一件難事。這是因為有太多的機器學習模型可以拿來直接用,比如tensorflow、caffe,在python下都有現成的安裝包,寫一個識别數字的程式,10幾行代碼足夠了。

然而我想做的,是不借助任何第三方的庫,從零開始,完全自己實作一個這樣的程式。之是以這麼做,是因為自己動手實作,才能深入了解機器學習的原理。

1 模型實作

1.1 原理

熟悉神經網絡回歸算法的,可以略過這一節了。

學習了一些基本概念,決定使用回歸算法。首先下載下傳了著名的MNIST資料集,這個資料集有60000個訓練樣本,和10000個測試樣本。每個數字圖檔都是2828的灰階圖檔,是以輸入可以認為是一個2828的矩陣,也可以認為是一個28*28=784個像素值。

這裡定義一個模型用于判斷一個圖檔數字,每個模型包括每個輸入的權重,加一個截距,最後再做個歸一。模型的表達式:

Out5= sigmoid(X0W0+ X1W1+……X783*W783+bias)

X0到X783是784個輸入,W0到W783是784個權重,bias是一個常量。sigmoid函數可以将較大範圍的數擠壓到(0,1)區間内,也就是歸一。

例如我們用這一組權重和bias來判斷數字5,期望當圖檔是5時輸出是1,當不是5時輸出是0。然後訓練的過程就是根據每個樣本的輸入,計算Out5的值和正确值(0或1)的差距,然後根據這個差距,調整權重和bias。轉換一下公式,就是在努力使得(Out5-正确值)接近于0,即所謂損失最小。

同理,10個數字就要有10套模型,每個判斷不同的數字。訓練好以後,一個圖檔來了,用這10套模型進行計算,哪個模型計算的結果更接近于1,就認為這個圖檔是哪個數字。

1.2 訓練

按照上面的思路,使用集算器的SPL(結構化處理語言)來編碼實作:

10 行代碼,實作手寫數字識别

不用再找了,訓練模型的所有代碼都在這裡了,沒有用到任何第三方庫,下面解析一下:

A1,用遊标導入MNIST訓練樣本,這個是我轉換過的格式,可以被集算器直接通路;

A2,定義變量:輸入x,權重wei,訓練速度v,等;

A3,B3,初始化10組模型(每組是784個權重+1個bias);

A4,循環取5萬個樣本進行訓練,10模型同時訓練;

B4,取出來label,即這個圖檔是幾;

B5,計算正确的10個輸出,儲存到變量y;

B6,取出來這個圖檔的28*28個像素點作為輸入,C6把每個輸入除以255,這是為了歸一化;

B7,計算X0W0+ X1W1+……X783*W783+bias

B8,計算sigmoid(B7)

B9,計算B8的偏導,或者叫梯度;

B10,C10,根據B9的值,循環調整10個模型的參數;

A11,訓練完畢,把模型儲存到檔案。

1.3 測試

測試一下這個模型的成功率吧,用 SPL 寫了一個測試程式:

10 行代碼,實作手寫數字識别

運作測試,正确率達到了91.1%,我對這個結果是很滿意的,畢竟這隻是一個單層模型,我用TensorFlow的單層模型得到的正确率也是91%多一點。下面解析一下代碼:

A1,導入模型檔案;

A2,把模型提取到變量裡;

A3,計數器初始化(用于計算成功率);

A4,導入MNIST測試樣本,這個檔案格式是我轉換過的;

A5,循環取1萬個樣本進行測試;

   B5,取出來label;

   B6,清空輸入;

B7,取出來這個圖檔的28*28個像素點作為輸入,每個輸入除以255,這是為了歸一化;

B8,計算X0W0+ X1W1+……X783*W783+bias

B9,計算sigmoid(B7)

B10,得到最大值,即最可能的那個數字;

B11,判斷正确測計數器加一;

A12,A13,測試結束,關閉檔案,輸出正确率。

1.4 優化

這裡要說的優化并不是繼續提高正确率,而是提升訓練的速度。想提高正确率的同學可以嘗試一下這幾個手段:

1.       加一個卷積層;

2.       學習速度不要用固定值,而是随着訓練次數遞減;

3.       權重的初始值不要使用全零,使用正态分布;

我認為單純追求正确率的意義不大,因為MNIST資料集有些圖檔本身就有問題,即使人工也不一定能知道寫的是數字幾。我用集算器顯示了幾張出錯的圖檔,都是書寫十分不規範的,下面這個圖檔很難看出來是2。

10 行代碼,實作手寫數字識别

下面說重點,要提高訓練速度,可以使用并行或叢集。使用SPL語言實作并行很簡單,隻要使用fork關鍵字,把上面的代碼稍加處理就可以了。

10 行代碼,實作手寫數字識别

使用了并行之後,訓練的時間減少差不多一半,而代碼并沒有做太多修改。

2 為什麼是 SPL 語言?

使用SPL語言在初期可能會有點不适應,用得多了會覺得越來越友善:

1.       支援集合運算,比如例子裡用到的784個輸入和784個權重的乘法,直接寫一個**就可以了,如果使用Java或者C,還要自己實作。

2.       資料的輸入輸出很友善,可以友善地對檔案讀寫。

3.       調試太友善了,所有變量都直覺可見,這一點比python要好用。

4.       可以單步計算,有了改動不用從頭重來,Java和C做不到這一點,python雖然可以但也不友善,集算器隻要點中相應格執行就可以了。

5.       實作并行和叢集很友善,不需要太多的開發工作量。

6.       支援調用和被調用。集算器可以調用第三方java庫,Java也可以調用集算器的代碼,例如上面的代碼就可以被Java調用,實作一個自動填驗證碼的功能。

這樣的程式設計語言,用在數學計算上,實在是最合适不過了。

154037421300096d9.rar