天天看點

《大資料算法學習》(三)MNIST手寫數字識别

一、使用MNIST資料集

    本次學習使用神經網絡識别手寫數字,我們使用的資料集是MNIST資料集,MNIST資料集的長相如下圖所示。

《大資料算法學習》(三)MNIST手寫數字識别

    MNIST資料集是由0 到9 的數字圖像構成。訓練圖像有6 萬張,測試圖像有1 萬張,這些圖像可以用于學習和推理。MNIST資料集的一般使用方法是,先用訓練圖像進行學習,再用學習到的模型度量能在多大程度上對測試圖像進行正确的分類。 MNIST的圖像資料是28 像素 × 28 像素的灰階圖像(1 通道),各個像素的取值在0 到255 之間。每個圖像資料都相應地标有“7”、“2”、“1”等标簽。

    load_mnist函數以“( 訓練圖像, 訓練标簽),( 測試圖像,測試标簽)”的形式傳回讀入的MNIST資料。

def load_mnist():
    train_labels_path = 'train-labels.idx1-ubyte'
    test_labels_path = 't10k-labels.idx1-ubyte'
    train_images_path = 'train-images.idx3-ubyte'
    test_images_path = 't10k-images.idx3-ubyte'

    with open(train_labels_path, 'rb') as lpath:
        magic, n = struct.unpack('>II', lpath.read(8))
        train_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float)

    with open(train_images_path, 'rb') as ipath:
        magic, num, rows, cols = struct.unpack('>IIII', ipath.read(16))
        loaded = np.fromfile(train_images_path, dtype=np.uint8)
        train_images = loaded[16:].reshape(len(train_labels), 784).astype(np.float)

    with open(test_labels_path, 'rb') as lpath:
        magic, n = struct.unpack('>II', lpath.read(8))
        test_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float)

    with open(test_images_path, 'rb') as ipath:
        magic, num, rows, cols = struct.unpack('>IIII', ipath.read(16))
        loaded = np.fromfile(test_images_path, dtype=np.uint8)
        test_images = loaded[16:].reshape(len(test_labels), 784)

    return train_images, train_labels, test_images, test_labels

img_train,label_train,img_test,label_test = load_mnist()
           

    MNIST資料集有四個檔案,分别代表列訓練圖像、訓練标簽、測試圖像、測試标簽,下載下傳位址如下:

http://yann.lecun.com/exdb/mnist/

二、神經網絡的推理處理

    神經網絡的輸入層有784 個神經元,輸出層有10 個神經元。輸入層的784 這個數字來源于圖像大小的28 × 28 = 784,輸出層的10 這個數字來源于10 類别分類(數字0 到9,共10 類别)。此外,這個神經網絡有2 個隐藏層,第1 個隐藏層有50 個神經元,第2 個隐藏層有100 個神經元。這個50 和100 可以設定為任何值。

    我們本次學習使用的權重是已經訓練好的權重資料,準确率可以達到94%,權重資料檔案名為sample_weight.pkl。

資料和代碼的下載下傳位址:https://download.csdn.net/download/zhiyeegao/12277801

import pickle
def init_network():
    with open("sample_weight.pkl","rb") as f:
        network = pickle.load(f)
           

    用神經網絡進行預測:

def sigmoid(x):
    return 1/(1+np.exp(-x))

def softmax(a):
    exp_a = np.exp(a)
    sum_exp_a = np.sum(exp_a)
    y = exp_a / sum_exp_a
    return y

def predict(network,x):
    W1,W2,W3 = network['W1'],network['W2'],network['W3']
    b1,b2,b3 = network['b1'],network['b2'],network['b3']
    a1 = np.dot(x,W1)+b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1,W2)+b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2,W3)+b3
    y = softmax(a3)
    return y
           

三、批處理

    用批處理統計訓練标簽和測試标簽每萬份資料中相同的數字的個數是多少。(訓練标簽有六萬個,測試标簽有一萬個)

batch_szie = 10000
all_same_count = 0
for i in range(0,len(label_train),batch_szie):
    label_train_batch = label_train[i:i+batch_szie]
    same_count = 0
    same_count += np.sum(label_train_batch == label_test)
    all_same_count += same_count
    print("每萬份資料中相同數字個數:"+str(same_count))
print("總數:"+str(all_same_count))


每萬份資料中相同數字個數:1008
每萬份資料中相同數字個數:1034
每萬份資料中相同數字個數:941
每萬份資料中相同數字個數:1004
每萬份資料中相同數字個數:1018
每萬份資料中相同數字個數:1014
總數:6019
           

繼續閱讀