程式
% 讀取MATLAB自帶數字圖像資料集,資料集有10000幅0-9圖像,各數字有1000幅圖像
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
% 将每個标簽檔案夾中的檔案随機拆分為兩組,750個為imdsTrain,其餘為imdsTest
numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');
% 定義一個簡單的三層CNN網絡結構
layers = [
% 輸入層,接收28×28×1的灰階圖像
imageInputLayer([28 28 1])
% 卷積層,使用3×3的卷積核,輸出8個特征圖,保持邊界不變
convolution2dLayer(3,8,'Padding','same')
% 歸一化層,對卷積層的輸出進行歸一化處理,提高訓練效率和泛化能力
batchNormalizationLayer
% 激活層,使用ReLU函數作為激活函數,增加非線性特征
reluLayer
% 池化層,使用2×2的最大池化核,步長為2,降低特征圖的尺寸和參數數量
maxPooling2dLayer(2,'Stride',2)
% 卷積層,使用3×3的卷積核,輸出16個特征圖,保持邊界不變
convolution2dLayer(3,16,'Padding','same')
% 歸一化層,對卷積層的輸出進行歸一化處理,提高訓練效率和泛化能力
batchNormalizationLayer
% 激活層,使用ReLU函數作為激活函數,增加非線性特征
reluLayer
% 池化層,使用2×2的最大池化核,步長為2,降低特征圖的尺寸和參數數量
maxPooling2dLayer(2,'Stride',2)
% 卷積層,使用3×3的卷積核,輸出32個特征圖,保持邊界不變
convolution2dLayer(3,32,'Padding','same')
% 歸一化層,對卷積層的輸出進行歸一化處理,提高訓練效率和泛化能力
batchNormalizationLayer
% 激活層,使用ReLU函數作為激活函數,增加非線性特征
reluLayer
% 全連接配接層,将卷積層的輸出展平為一維向量,并連接配接到10個神經元上,對應10個類别(0-9)
fullyConnectedLayer(10)
% softmax層,将全連接配接層的輸出轉換為機率分布
softmaxLayer
% 分類層,根據softmax層的輸出和真實标簽計算損失函數,并評估分類準确率
classificationLayer];
% 設定訓練參數
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ... % 初始學習率為0.01
'MaxEpochs',5, ... % 最大疊代次數為5次
'Shuffle','every-epoch', ... % 每次疊代都打亂資料順序
'ValidationData',imdsTest, ... % 使用imdsTest作為驗證資料集
'ValidationFrequency',30, ... % 每30次疊代進行一次驗證
'Verbose',false, ... % 不在指令視窗顯示訓練過程資訊
'Plots','training-progress'); % 繪制訓練進度圖
% 訓練網絡,并儲存訓練好的模型
net = trainNetwork(imdsTrain,layers,options);
save net
% 使用classify函數對測試資料進行分類,并計算準确率
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;
accuracy = sum(YPred == YTest)/numel(YTest);
% 使用imshow函數顯示一些測試圖檔和分類結果
figure;
perm = randperm(2500,20);
for i = 1:20
subplot(4,5,i);
s = classify(net,readimage(imdsTest,perm(i)));
imshow(imdsTest.Files{perm(i)});
title(string(s));
end
訓練進度
識别結果
識别結果