天天看點

CNN實作MNIST手寫數字識別

這個很簡單,直接上程式碼,並附上結果圖片。

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
# YOUR CODE STARTS HERE
class Mycallbacks(tf.keras.callbacks.Callback):
    """Keras callback that stops training once training accuracy exceeds 90.8%."""

    def on_epoch_end(self, epoch, logs=None):
        """Inspect the metrics of the epoch that just finished.

        Args:
            epoch: Index of the finished epoch (not used).
            logs: Metrics dict passed in by Keras; may be ``None``.
        """
        # ``logs=None`` avoids sharing one mutable default dict across calls.
        accuracy = (logs or {}).get('accuracy')
        # The 'accuracy' key is absent when the model was compiled without
        # that metric; comparing ``None > 0.908`` would raise TypeError.
        if accuracy is not None and accuracy > 0.908:
            print('\nReached 90.8% accuracy so cancelling training!')
            self.model.stop_training = True


# Early-stopping callback instance; passed to model.fit() further down.
callbacks=Mycallbacks()
# YOUR CODE ENDS HERE

# Load the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

# YOUR CODE STARTS HERE

# Add a trailing channel axis so each sample is 28x28x1, matching the
# input_shape of the first Conv2D layer. -1 lets NumPy infer the sample
# count instead of hard-coding 60000 / 10000.
training_images = training_images.reshape(-1, 28, 28, 1)
# Scale pixel values by 1/255 before feeding them to the network.
training_images = training_images / 255.0
test_images = test_images.reshape(-1, 28, 28, 1)
test_images = test_images / 255.0



# YOUR CODE ENDS HERE

# Two conv/pool/dropout stages followed by a dense classifier with 10 outputs.
model = tf.keras.models.Sequential([
    # YOUR CODE STARTS HERE
    # Stage 1: only the first layer declares the input shape (28x28, 1 channel).
    tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), activation='relu',
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Dropout(0.5),

    # Stage 2: the input shape is inferred from the previous layer, so no
    # input_shape here (the copy-pasted (28, 28, 1) was ignored and misleading).
    tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Dropout(0.5),

    # Classifier head: flatten feature maps, one hidden layer, softmax over 10 classes.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),

    # YOUR CODE ENDS HERE
])

# YOUR CODE STARTS HERE
# Print the layer-by-layer architecture.
model.summary()

# Integer class labels are used directly, hence the sparse cross-entropy loss.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for at most 20 epochs; the callback may stop training earlier.
model.fit(
    training_images,
    training_labels,
    epochs=20,
    callbacks=[callbacks],
)

# Report loss and accuracy on the held-out test split.
model.evaluate(test_images, test_labels)
# YOUR CODE ENDS HERE


# Use the SimHei font so the Chinese figure title renders correctly.
plt.rcParams['font.sans-serif']='SimHei'
# Per-class scores for every test image. The default batch size is used:
# batch_size=1 forced one forward pass per image and was needlessly slow.
result = model.predict(test_images)
def show_pic(n):
    """Show ``n`` randomly chosen test images titled with their predicted digit.

    Reads the module-level ``test_images`` and ``result`` arrays and draws
    the images on a grid with 4 rows.

    Args:
        n: Number of images to display.
    """
    rows = 4
    # Ceiling division: subplot() needs an integer column count (n/4 is a
    # float in Python 3), and the grid must hold n cells even when n is not
    # a multiple of 4.
    cols = -(-n // rows)
    plt.suptitle('測試結果')
    for i in range(n):
        # Sample over the whole test set, index 0 included, without
        # hard-coding its size.
        num = np.random.randint(0, len(test_images))
        plt.subplot(rows, cols, i + 1)  # 4-row grid, current cell is i+1
        plt.axis('off')  # hide the axes
        # Drop the trailing channel axis so imshow gets a plain 2-D image.
        plt.imshow(np.squeeze(test_images[num]), cmap='gray')
        # np.argmax() returns the index of the highest score, i.e. the predicted digit.
        plt.title(str(np.argmax(result[num])))
    # Leave the top 10% of the figure free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.9])
    plt.show()

# Display 20 random test images together with their predicted labels.
show_pic(20)
           

結果

CNN實作mnist手寫數字識别
CNN實作mnist手寫數字識别

繼續閱讀