天天看點

基于tensorflow+RNN的MNIST資料集手寫數字分類

2018年9月25日筆記

tensorflow是谷歌google的深度學習架構,tensor中文叫做張量,flow叫做流。 RNN是recurrent neural network的簡稱,中文叫做循環神經網絡。 MNIST是Mixed National Institue of Standards and Technology database的簡稱,中文叫做美國國家标準與技術研究所資料庫。 此文在上一篇文章《基于tensorflow+DNN的MNIST資料集手寫數字分類預測》的基礎上修改模型為循環神經網絡模型,模型準确率從98%提升到98.5%,錯誤率減少了25% 《基于tensorflow+DNN的MNIST資料集手寫數字分類預測》文章連結:https://www.jianshu.com/p/9a4ae5655ca6

0.程式設計環境

作業系統:Win10 tensorflow版本:1.6 tensorboard版本:1.6 python版本:3.6

1.緻謝聲明

本文是作者學習《周莫煩tensorflow視訊教程》的成果,感激前輩; 視訊連結:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/

2.配置環境

使用循環神經網絡模型要求有較高的機器配置,如果使用CPU版tensorflow會花費大量時間。 讀者在有nvidia顯示卡的情況下,安裝GPU版tensorflow會提高計算速度50倍。 安裝教程連結:https://blog.csdn.net/qq_36556893/article/details/79433298 如果沒有nvidia顯示卡,但有visa信用卡,請閱讀我的另一篇文章《在谷歌雲伺服器上搭建深度學習平台》,連結:https://www.jianshu.com/p/893d622d1b5a

3.下載下傳并解壓資料集

MNIST資料集下載下傳連結: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密碼: wa9p 下載下傳壓縮檔案MNIST_data.rar完成後,選擇解壓到目前檔案夾,不要選擇解壓到MNIST_data。 檔案夾結構如下圖所示:

基于tensorflow+RNN的MNIST資料集手寫數字分類

image.png

4.完整代碼

此章給讀者能夠直接運作的完整代碼,使讀者有程式設計結果的感性認識。 如果下面一段代碼運作成功,則說明安裝tensorflow環境成功。 想要了解代碼的具體實作細節,請閱讀後面的章節。 完整代碼中定義函數RNN使代碼簡潔,但在後面章節中為了易于讀者了解,本文作者在第6章搭建神經網絡将此部分函數改寫為隻針對于該題的順序執行代碼。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

def RNN(X_holder):
    reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
    cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    last_cell = cell_list[-1]
    Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
    biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
    predict_Y = tf.matmul(last_cell, Weights) + biases
    return predict_Y
predict_Y = RNN(X_holder)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.train.next_batch(3000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print(step, "{:.4f}".format(test_accuracy))           

複制

上面一段代碼的運作結果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz Extracting MNIST_data\train-labels-idx1-ubyte.gz Extracting MNIST_data\t10k-images-idx3-ubyte.gz Extracting MNIST_data\t10k-labels-idx1-ubyte.gz 100 0.852 200 0.888 300 0.939 400 0.930 500 0.946 600 0.959 700 0.953 800 0.948 900 0.956 1000 0.958

5.資料準備

第1行代碼導入庫warnings; 第2行代碼表示不列印警告資訊; 第3行代碼導入庫tensorflow,取别名tf; 第4行代碼從tensorflow.examples.tutorials.mnist庫中導入input_data方法; 第6行代碼表示重置tensorflow圖 第7行代碼加載資料庫MNIST指派給變量mnist; 第8-13行代碼定義超參數學習率learning_rate、批量大小batch_size、步數n_steps、輸入層大小n_inputs、隐藏層大小n_hidden_units、輸出層大小n_classes。 第14、15行代碼中placeholder中文叫做占位符,将每次訓練的特征矩陣X和預測目标值Y指派給變量X_holder和Y_holder。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)           

複制

6.搭建神經網絡

本文作者将此章中使用tensorflow庫的所有方法的API連結總結成下表,通路需要V**。

方法 連結
tf.reshape https://www.tensorflow.org/api_docs/python/tf/manip/reshape
tf.nn.rnn_cell.LSTMCell https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell
tf.nn.dynamic_rnn https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn
tf.transpose https://www.tensorflow.org/api_docs/python/tf/transpose
tf.unstack https://www.tensorflow.org/api_docs/python/tf/unstack
tf.Variable https://www.tensorflow.org/api_docs/python/tf/Variable
tf.truncated_normal https://www.tensorflow.org/api_docs/python/tf/truncated_normal
tf.matmul https://www.tensorflow.org/api_docs/python/tf/matmul
tf.reduce_mean https://www.tensorflow.org/api_docs/python/tf/reduce_mean
tf.nn.softmax_cross_entropy_with_logits https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits
tf.train.AdamOptimizer https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer

第1行代碼reshape中文叫做重塑形狀,将輸入資料X_holder重塑形狀為模型需要的; 第2行代碼調用tf.nn.rnn_cell.LSTMCell方法執行個體化LSTM細胞對象; 第3行代碼調用tf.nn.dynamic_rnn方法執行個體化rnn模型對象; 第4、5行代碼取得rnn模型中最後一個細胞的數值; 第6、7行代碼定義在訓練過程會更新的權重Weights、偏置biases; 第8行代碼表示

xW+b

的計算結果指派給變量predict_Y,即預測值; 第9行代碼表示交叉熵作為損失函數loss; 第10行代碼表示AdamOptimizer作為優化器optimizer; 第11行代碼定義訓練過程,即使用優化器optimizer最小化損失函數loss。

reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden_units)
outputs, state = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
last_cell = cell_list[-1]
Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
predict_Y = tf.matmul(last_cell, Weights) + biases
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)           

複制

7.參數初始化

對于神經網絡模型,重要是其中的W、b這兩個參數。 開始神經網絡模型訓練之前,這兩個變量需要初始化。 第1行代碼調用tf.global_variables_initializer執行個體化tensorflow中的Operation對象。

基于tensorflow+RNN的MNIST資料集手寫數字分類

image.png

第2行代碼調用tf.Session方法執行個體化會話對象; 第3行代碼調用tf.Session對象的run方法做變量初始化。

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)           

複制

8.模型訓練

第1行代碼tf.argmax方法中的第2個參數為1,即求出矩陣中每1行中最大數的索引; 如果argmax方法中的第1個參數為0,即求出矩陣中每1列最大數的索引; tf.equal方法可以比較兩個向量的在每個元素上是否相同,傳回結果為向量,向量中元素的資料類型為布爾bool; 第2行代碼

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.test.next_batch(10000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print(step, "{:.4f}".format(test_accuracy))            

複制

上面一段代碼的運作結果如下:

100 0.8272 200 0.9071 300 0.9334 400 0.9441 500 0.9459 600 0.9585 700 0.9548 800 0.9664 900 0.9654 1000 0.9671

文章篇幅所限,隻列印檢視1000次訓練的結果,訓練5000次即可達到98.5%的準确率。