天天看點

使用MATLAB神經網絡生成手寫數字識别程式

作者:sunbhtt

程式

% 讀取MATLAB自帶數字圖像資料集,資料集有10000幅0-9圖像,各數字有1000幅圖像
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% 将每個标簽檔案夾中的檔案随機拆分為兩組,750個為imdsTrain,其餘為imdsTest
numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');

% 定義一個簡單的三層CNN網絡結構
layers = [
    % 輸入層,接收28×28×1的灰階圖像
    imageInputLayer([28 28 1])
    
    % 卷積層,使用3×3的卷積核,輸出8個特征圖,保持邊界不變
    convolution2dLayer(3,8,'Padding','same')
    % 歸一化層,對卷積層的輸出進行歸一化處理,提高訓練效率和泛化能力
    batchNormalizationLayer
    % 激活層,使用ReLU函數作為激活函數,增加非線性特征
    reluLayer
    
    % 池化層,使用2×2的最大池化核,步長為2,降低特征圖的尺寸和參數數量
    maxPooling2dLayer(2,'Stride',2)
    
    % 卷積層,使用3×3的卷積核,輸出16個特征圖,保持邊界不變
    convolution2dLayer(3,16,'Padding','same')
    % 歸一化層,對卷積層的輸出進行歸一化處理,提高訓練效率和泛化能力
    batchNormalizationLayer
    % 激活層,使用ReLU函數作為激活函數,增加非線性特征
    reluLayer
    
    % 池化層,使用2×2的最大池化核,步長為2,降低特征圖的尺寸和參數數量
    maxPooling2dLayer(2,'Stride',2)
    
    % 卷積層,使用3×3的卷積核,輸出32個特征圖,保持邊界不變
    convolution2dLayer(3,32,'Padding','same')
    % 歸一化層,對卷積層的輸出進行歸一化處理,提高訓練效率和泛化能力
    batchNormalizationLayer
    % 激活層,使用ReLU函數作為激活函數,增加非線性特征
    reluLayer
    
    % 全連接配接層,将卷積層的輸出展平為一維向量,并連接配接到10個神經元上,對應10個類别(0-9)
    fullyConnectedLayer(10)
    % softmax層,将全連接配接層的輸出轉換為機率分布
    softmaxLayer
    % 分類層,根據softmax層的輸出和真實标簽計算損失函數,并評估分類準确率
    classificationLayer];

% 設定訓練參數
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ... % 初始學習率為0.01
    'MaxEpochs',5, ... % 最大疊代次數為5次
        'Shuffle','every-epoch', ... % 每次疊代都打亂資料順序
    'ValidationData',imdsTest, ... % 使用imdsTest作為驗證資料集
    'ValidationFrequency',30, ... % 每30次疊代進行一次驗證
    'Verbose',false, ... % 不在指令視窗顯示訓練過程資訊
    'Plots','training-progress'); % 繪制訓練進度圖

% 訓練網絡,并儲存訓練好的模型
net = trainNetwork(imdsTrain,layers,options);
save net

% 使用classify函數對測試資料進行分類,并計算準确率
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;
accuracy = sum(YPred == YTest)/numel(YTest);

% 使用imshow函數顯示一些測試圖檔和分類結果
figure;
perm = randperm(2500,20);
for i = 1:20
    subplot(4,5,i);
    s = classify(net,readimage(imdsTest,perm(i)));
    imshow(imdsTest.Files{perm(i)});
    title(string(s));
end

           
使用MATLAB神經網絡生成手寫數字識别程式

訓練進度

識别結果

使用MATLAB神經網絡生成手寫數字識别程式

識别結果

繼續閱讀