天天看點

Tensorflow2.0-mnist手寫數字識别示例

Tensorflow2.0-mnist手寫數字識别示例

     

      讀書不覺春已深,一寸光陰一寸金。

簡介:通過CNN 卷積神經網絡訓練後識别出手寫圖檔,測試圖檔mnist資料集中的0、1、2、4。

Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例

一、mnist資料集準備

     雖然可以通過代碼自動下載下傳資料集,但是mnist 資料集國内下載下傳不穩定,會出現【Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz】的情況,代碼從定義目錄data_set_tf3 中未擷取到mnist 資料集就會自動下載下傳,但下載下傳時間比較久,還是提前準備好。

Downloading mnist data from https

Tensorflow2.0-mnist手寫數字識别示例
mnist資料集下載下傳位址
Tensorflow2.0-mnist手寫數字識别示例

mnist資料集官網如上,下載下傳下面四個東西就可以了,圖中标紅的兩個images和lables。

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本)

Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個标簽)

Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本)

Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個标簽)

      MNIST 資料集來自美國國家标準與技術研究所,  訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 的從業人員;測試集(test set) 也是同樣比例的手寫數字資料;可以建立一個檔案夾 – mnist, 将資料集下載下傳到 mnist 解壓即可。

mnist資料集整合

Tensorflow2.0-mnist手寫數字識别示例

三、圖檔訓練

train.py 訓練代碼如下:

Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例

1 import os
 2 import tensorflow as tf
 3 from tensorflow.keras import datasets, layers, models
 4 
 5 '''
 6 python 3.7、3.9
 7 tensorflow 2.0.0b0
 8 '''
 9 
10 # 模型定義的前半部分主要使用Keras.layers 提供的Conv2D(卷積)與MaxPooling2D(池化)函數。
11 # CNN的輸入是次元為(image_height, image_width, color_channels)的張量,
12 # mnist資料集是黑白的,是以隻有一個color_channels 顔色通道;一般的彩色圖檔有3個(R, G, B),
13 # 也有4個通道的(R, G, B, A),A代表透明度;
14 # 對于mnist資料集,輸入的張量次元為(28, 28, 1),通過參數input_shapa 傳給網絡的第一層
15 # CNN模型處理:
16 class CNN(object):
17     def __init__(self):
18         model = models.Sequential()
19         # 第1層卷積,卷積核大小為3*3,32個,28*28為待訓練圖檔的大小
20         model.add(layers.Conv2D(
21             32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
22         model.add(layers.MaxPooling2D((2, 2)))
23         # 第2層卷積,卷積核大小為3*3,64個
24         model.add(layers.Conv2D(64, (3, 3), activation='relu'))  # 使用神經網絡中激活函數ReLu
25         model.add(layers.MaxPooling2D((2, 2)))
26         # 第3層卷積,卷積核大小為3*3,64個
27         model.add(layers.Conv2D(64, (3, 3), activation='relu'))
28 
29         model.add(layers.Flatten())
30         model.add(layers.Dense(64, activation='relu'))
31         model.add(layers.Dense(10, activation='softmax'))
32         # Flatten層用來将輸入“壓平”,即把多元的輸入一維化,常用在從卷積層到全連接配接層的過渡。Flatten不影響batch的大小
33         # dense :全連接配接層相當于添加一個層
34         # softmax用于多分類過程中,它将多個神經元的輸出,映射到(0,1)區間内,可以看成機率來了解,進而來進行多分類!
35         model.summary()  # 輸出模型各層的參數狀況
36 
37         self.model = model
38 
39 
40 # mnist資料集預處理
41 class DataSource(object):
42     def __init__(self):
43         # mnist資料集存儲的位置,如果不存在将自動下載下傳
44         data_path = os.path.abspath(os.path.dirname(
45             __file__)) + '/../data_set_tf2/mnist.npz'
46         (train_images, train_labels), (test_images,
47                                        test_labels) = datasets.mnist.load_data(path=data_path)
48         # 6萬張訓練圖檔,1萬張測試圖檔
49         train_images = train_images.reshape((60000, 28, 28, 1))
50         test_images = test_images.reshape((10000, 28, 28, 1))
51         # 像素值映射到 0 - 1 之間
52         train_images, test_images = train_images / 255.0, test_images / 255.0
53 
54         self.train_images, self.train_labels = train_images, train_labels
55         self.test_images, self.test_labels = test_images, test_labels
56 
57 
58 # 開始訓練并儲存訓練結果
59 class Train:
60     def __init__(self):
61         self.cnn = CNN()
62         self.data = DataSource()
63 
64     def train(self):
65         check_path = './ckpt/cp-{epoch:04d}.ckpt'
66         # period 每隔5epoch儲存一次
67         save_model_cb = tf.keras.callbacks.ModelCheckpoint(
68             check_path, save_weights_only=True, verbose=1, period=5)
69 
70         self.cnn.model.compile(optimizer='adam',
71                                loss='sparse_categorical_crossentropy',
72                                metrics=['accuracy'])
73         self.cnn.model.fit(self.data.train_images, self.data.train_labels,
74                            epochs=5, callbacks=[save_model_cb])
75 
76         test_loss, test_acc = self.cnn.model.evaluate(
77             self.data.test_images, self.data.test_labels)
78         print("準确率: %.4f,共測試了%d張圖檔 " % (test_acc, len(self.data.test_labels)))
79 
80 
81 if __name__ == "__main__":
82     app = Train()
83     app.train()      

View Code~拍一拍小輪胎

 mnist手寫數字識别訓練了四分鐘左右,準确率高達0.9902,下面的視訊隻截取了訓練的前十秒。

 mnist手寫數字識别訓練視訊

model.summary()列印定義的模型結構

Tensorflow2.0-mnist手寫數字識别示例

CNN定義的模型結構

Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例
1 Model: "sequential"
 2 _________________________________________________________________
 3 Layer (type)                 Output Shape              Param #   
 4 =================================================================
 5 conv2d (Conv2D)              (None, 26, 26, 32)        320       
 6 _________________________________________________________________
 7 max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
 8 _________________________________________________________________
 9 conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
10 _________________________________________________________________
11 max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
12 _________________________________________________________________
13 conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
14 _________________________________________________________________
15 flatten (Flatten)            (None, 576)               0         
16 _________________________________________________________________
17 dense (Dense)                (None, 64)                36928     
18 _________________________________________________________________
19 dense_1 (Dense)              (None, 10)                650       
20 =================================================================
21 Total params: 93,322
22 Trainable params: 93,322
23 Non-trainable params: 0
24 _________________________________________________________________      

View Code

      我們可以看到,每一個Conv2D 和MaxPooling2D 層的輸出都是一個三維的張量(height, width, channels),height 和width 會逐漸地變小;輸出的channel 的個數,是由第一個參數(例如,32或64)控制的,随着height 和width 的變小,channel可以變大(從算力的角度)。

      模型的後半部分,是定義張量的輸出。layers.Flatten 會将三維的張量轉為一維的向量,展開前張量的次元是(3, 3, 64) ,轉為一維(576)【3*3*64】的向量後,緊接着使用layers.Dense 層,構造了2層全連接配接層,逐漸地将一維向量的位數從576變為64,再變為10。

      後半部分相當于是建構了一個隐藏層為64,輸入層為576,輸出層為10的普通的神經網絡。最後一層的激活函數是softmax,10位恰好可以表達0-9十個數字。最大值的下标即可代表對應的數字,使用numpy 的argmax() 方法擷取最大值下标,很容易計算得到預測值。

train.py運作結果

Tensorflow2.0-mnist手寫數字識别示例

      可以看到,在第一輪訓練後,識别準确率達到了0.9536,五輪訓練之後,使用測試集驗證,準确率達到了0.9902。在第五輪時,模型參數成功儲存在了./ckpt/cp-0005.ckpt,而且此時準确率為更高的0.9940,是以也并不是訓練時間次數越久越好,過猶不及。可以加載儲存的模型參數,恢複整個卷積神經網絡,進行真實圖檔的預測。

儲存訓練模型參數

Tensorflow2.0-mnist手寫數字識别示例

四、圖檔預測

predict.py代碼如下:

Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例
1 import tensorflow as tf
 2 from PIL import Image
 3 import numpy as np
 4 
 5 from mnist.v4_cnn.train import CNN
 6 
 7 '''
 8 python 3.7 3.9
 9 tensorflow 2.0.0b0
10 pillow(PIL) 4.3.0
11 '''
12 
13 
14 class Predict(object):
15     def __init__(self):
16         latest = tf.train.latest_checkpoint('./ckpt')
17         self.cnn = CNN()
18         # 恢複網絡權重
19         self.cnn.model.load_weights(latest)
20 
21     def predict(self, image_path):
22         # 以黑白方式讀取圖檔
23         img = Image.open(image_path).convert('L')
24         img = np.reshape(img, (28, 28, 1)) / 255.
25         x = np.array([1 - img])
26 
27         # API refer: https://keras.io/models/model/
28         y = self.cnn.model.predict(x)
29 
30         # 因為x隻傳入了一張圖檔,取y[0]即可
31         # np.argmax()取得最大值的下标,即代表的數字
32         print(image_path)
33         print(y[0])
34         print('        -> Predict picture number is: ', np.argmax(y[0]))
35 
36 
37 if __name__ == "__main__":
38     app = Predict()
39     app.predict('../test_images/0.png')
40     app.predict('../test_images/1.png')
41     app.predict('../test_images/4.png')
42     app.predict('../test_images/2.png')      

預測結果

Tensorflow2.0-mnist手寫數字識别示例

 預測結果:

Tensorflow2.0-mnist手寫數字識别示例
Tensorflow2.0-mnist手寫數字識别示例
1 ../test_images/0.png
 2 [9.9999774e-01 2.6819215e-08 1.2541744e-07 8.7437911e-08 1.0661940e-09
 3  3.3693670e-08 4.6488995e-07 3.5915035e-09 9.8040758e-08 1.4385278e-06]
 4         -> Predict picture number is:  0
 5 ../test_images/1.png
 6 [7.75440956e-09 9.99991298e-01 1.41642090e-07 1.09819875e-10
 7  6.76554646e-06 7.63710162e-09 2.37024622e-08 1.58189516e-06
 8  2.49125264e-07 4.92376007e-09]
 9         -> Predict picture number is:  1
10 ../test_images/4.png
11 [7.03467840e-10 8.20740708e-04 1.11648405e-04 3.93262711e-09
12  9.99048650e-01 1.08713095e-07 4.24647197e-08 1.85665340e-05
13  5.03181887e-08 1.86591734e-07]
14         -> Predict picture number is:  4
15 ../test_images/2.png
16 [1.5828672e-08 1.9245699e-07 9.9999440e-01 5.3448480e-06 1.7397912e-10
17  8.6148493e-13 2.5441890e-10 5.3953073e-08 3.5735226e-08 8.9734775e-11]
18         -> Predict picture number is:  2      

如上,經CNN訓練後通過模型參數準确預測出了0、1、2、4四張手寫圖檔的真實值。

                 

    

 讀書不覺春已深

                            一寸光陰一寸金

繼續閱讀