天天看點

keras入門(一)——遷移VGG16模型訓練mnist資料集實作手寫數字識别

作為一個剛入門keras的小白,實戰的時候參照網上修改VGG16模型訓練mnist資料集實作手寫數字識别,掉進了不少坑,走了不少彎路,也學習了很多知識。下面跟大家分享一下,有問題歡迎大家評論指出~

整個程式參考了以下文檔和視訊:

人工智能深度學習第21講:遷移學習

b站視訊

(一)修改VGG16模型

model_vgg=VGG16(include_top=False,weights="imagenet",input_shape=(48,48,3))
for layer in model_vgg.layers:
    layer.trainable=False
model=Flatten(name="flatten")(model_vgg.output)
model=Dense(10,activation="softmax")(model)
model_vgg_mnist=Model(inputs=model_vgg.input,outputs=model,name="vgg16")
           

1.keras中VGG16模型

VGG16模型,權重由ImageNet訓練而來,模型的預設輸入尺寸是224x224,但是最小是48x48,這個一定要注意。

keras.applications.vgg16.VGG16(include_top=True,weights=‘imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)

參數:

include_top:是否保留頂層的3個全連接配接網絡

weights:None代表随機初始化,即不加載預訓練權重。'imagenet’代表加載預訓練權重

input_tensor:可填入Keras tensor作為模型的圖像輸出tensor

input_shape:可選,僅當include_top=False有效,應為長為3的tuple,指明輸入圖檔的shape,圖檔的寬高必須大于48,如(200,200,3)

2.for循環

固定住模型中卷積層和池化層的參數,不讓他們進行訓練。可以從截圖中看到,要訓練的參數大大減少。

keras入門(一)——遷移VGG16模型訓練mnist資料集實作手寫數字識别

3.添加其他層,組合

為了減少訓練時間我就直接添加Flatten層然後分類了,大家可以自行添加全連接配接層。

然後通過Model将自己添加的層和VGG模型組合起來。

(二)編譯模型

sgd=SGD(lr=0.05,decay=1e-5)
model_vgg_mnist.compile(optimizer=sgd,loss="categorical_crossentropy",metrics=['accuracy'])
           

(三)加載mnist資料集,修改資料集尺寸、類型

#加載mnist資料集,為了縮短訓練時間取資料集前10000個
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_train,y_train=x_train[:10000],y_train[:10000]
x_test,y_test=x_test[:10000],y_test[:10000]

#修改資料集的尺寸、将灰階圖像轉換為rgb圖像
x_train=[cv2.cvtColor(cv2.resize(i,(48,48)),cv2.COLOR_GRAY2BGR)for i in x_train]
x_train=np.concatenate([arr[np.newaxis]for arr in x_train]).astype('float32')
x_train=x_train/255
x_test=[cv2.cvtColor(cv2.resize(i,(48,48)),cv2.COLOR_GRAY2BGR)for i in x_test]
x_test=np.concatenate([arr[np.newaxis]for arr in x_test]).astype('float32')
x_test=x_test/255
           

1.着重說一下修改圖像的尺寸和類型!

我一開始不懂np.concatenate和np.newaxis,就尋思用for循環不也能來實作修改圖像的尺寸和類型嘛。

for i in range(len(x_train)):
    x_train[i]=cv2.resize(x_train[i],(48,48))
           

然後就殘忍的報錯了:

keras入門(一)——遷移VGG16模型訓練mnist資料集實作手寫數字識别

我想之前我用opencv這個resize函數處理圖像沒有問題啊,圖像不就是一個數組嘛,那我取出來x_train其中一個元素不也是數組嘛,放進resize為什麼不行。

經過我苦苦掙紮後,才明白“對象”的意義!resize輸入參數雖然是數組,但是輸入的是一個獨立的數組,也就是數組這個對象。而我取出來x_train的一個元素,這個元素雖然代表一個圖檔,但是他不是獨立的一個,他是x_train這個數組對象的一部分。是以程式通過list清單的方式改變每一個圖像的尺寸和類型,然後又用np.concatenate和np.newaxis将x_train這個清單還原成一個數組。

2.np.concatenate和np.newaxis

首先x_train這個清單裡面(上一步把x_train清單化了已經),每一個元素都是一個圖檔,大小為(48, 48, 3)。我們需要得到的x_train數組的shape為(10000,48,48,3)。是以我們需要:

  1. 第一步:通過np.newaxis函數把每一個圖檔增加一個次元變成(1,48,48,3)。是以就有了程式中的arr[np.newaxis]。
  2. 第二步:通過np.concatenate把每個數組連接配接起來組成一個新的x_train數組,連接配接後的x_train數組shape為(10000,48,48,3)

(四)編碼、訓練、評估、儲存模型

#編碼
y_train=np_utils.to_categorical(y_train,10)
y_test=np_utils.to_categorical(y_test,10)

#訓練模型
model_vgg_mnist.fit(x_train,y_train,epochs=3,batch_size=100)

#評估模型
result=model_vgg_mnist.evaluate(x_test,y_test)
print("loss:",result[0])
print("acc:",result[1])

#儲存模型
model_vgg_mnist.save('my_mnist_vgg16.model')
           

因為用cpu訓練,條件有限,為節省時間省略了全連接配接層,隻訓練了3輪,大約十五分鐘,正确率大約80%

(五)測試

model = load_model('my_mnist_vgg16.model')

img = cv2.imread('F:/Qin/Produce/keras/other_mnist_cnn/test_pic/4.jpg')
cv2.imshow("圖檔:4",img)
img =img.astype('float32')
img=cv2.resize(img,(48,48))
img = (img.reshape(1,48,48,3))/ 255

predict = model.predict(img)
predict=np.argmax(predict,axis=1)

print('識别為:')
print(predict)

cv2.waitKey(0)
           
keras入門(一)——遷移VGG16模型訓練mnist資料集實作手寫數字識别

預測模型類别這裡也要注意!

不能直接寫成predict = model.predict_classes(img),會報錯。

keras入門(一)——遷移VGG16模型訓練mnist資料集實作手寫數字識别

原因:

參考網址

The predict_classes method is only available for the Sequential class (which is the class of your first model) but not for the Model class (the class of your second model).

With the Model class, you can use the predict method which will give you a vector of probabilities and then get the argmax of this vector (with np.argmax(y_pred1,axis=1)).

第一次用的網絡是在model=Sequential()下添加子產品的的方法,Sequential class可以使用model.predict_classes()

第二次用的網絡是編寫好網絡結構後使用model=Model(input=mnist_input,outputs=output)綜合起來的方法,也就是Model class,無法使用model.predict_classes(),但是可以使用組合形式預測。

(六)測試結果

用自己手寫的0~9張數字圖檔進行識别,有4張識别正确,網絡還需要進一步完善和訓練。

繼續閱讀