天天看點

神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)

适用于 TensorFlow 2.3 及 keras 2.4.3 ,Python版本為3.8

如果你使用新版本的第三方庫,請考慮降級為本文适用的版本,或者自行查閱第三方庫的更新文檔修改代碼。

圖像分類資料集中最常用的是手寫數字識别資料集MNIST 。但大部分模型在MNIST上的分類精度都超過了95%。為了更直覺地觀察算法之間的差異,我們将使用一個圖像内容更加複雜的資料集Fashion-MNIST 

FashionMNIST 是圖像資料集,它是由 Zalando(一家德國的時尚科技公司)旗下的研究部門提供。其涵蓋了來自 10 種類别的共 7 萬個不同商品的正面圖檔。FashionMNIST 的大小、格式和訓練集/測試集劃分與原始的 MNIST 完全一緻。60000/10000 的訓練測試資料劃分,28x28 的灰階圖檔。友善我們進行測試各種神經網絡算法。 該資料集識别難度遠大于原有的MNIST資料集。

神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)
神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)

資料庫導入

所有代碼都用keras.datasets接口來加載fashion_mnist資料,從網絡上直接下載下傳fashion_mnist資料,無需從本地導入,十分友善。

這意味着你可以将代碼中的

(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

修改為(X_train, y_train), (X_test, y_test) = mnist.load_data()

就可以直接對原始的MNIST資料集進行訓練和識别。

1.Baseline版本代碼(MLP實作,識别成功率為87.6%)

BaseLine版本用的是MultiLayer Percepton(多層感覺機)。這個網絡結構比較簡單,輸入--->隐含--->輸出。隐含層采用的rectifier linear unit,輸出直接選取的softmax進行多分類。

神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.utils import np_utils
from keras.datasets import fashion_mnist

seed = 7
numpy.random.seed(seed)
#加載資料
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')

X_train = X_train / 255
X_test = X_test / 255

# 對輸出進行one hot編碼
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]

# MLP模型
def baseline_model():
     model = Sequential()
     model.add(Dense(num_pixels, input_dim=num_pixels,  activation='relu'))
     model.add(Dense(num_classes,  activation='softmax'))
     model.summary()
     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
     return model

# 建立模型
model = baseline_model()

# Fit
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)

#Evaluation
scores = model.evaluate(X_test, y_test, verbose=0)
print("使用MLP并疊代十次的正确率為:" ,'%.2f' %(scores[1]*100) , "%")
           

 2.CNN卷積神經網絡(疊代20次,識别成功率92%) 

神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.convolutional import Convolution2D
from keras.layers.convolutional import MaxPooling2D
from keras.utils import np_utils
from keras.datasets import fashion_mnist

seed = 7
numpy.random.seed(seed)

 #加載資料
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
# reshape to be [samples][channels][width][height]
X_train = X_train.reshape(X_train.shape[0],  28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0],  28, 28, 1).astype('float32')

 # normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255
 # one hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]

# define a simple CNN model
def baseline_model():
    # create model
    model = Sequential()
    model.add(Convolution2D(32,( 5, 5), padding='valid', input_shape=( 28, 28, 1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.1))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    # Compile model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# build the model
model = baseline_model()

# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=20, batch_size=128, verbose=2)

# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("使用CNN網絡的正确率為:" ,'%.2f' %(scores[1]*100) , "%")
           

附錄

神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)
神經網絡入門及改進優化——CNN識别Fashion-MNIST資料集(Python實作)