天天看點

MATLAB 使用CNN拟合回歸模型預測手寫數字的旋轉角度(卷積神經網絡)

學習來源自mathworks的官方範例,個人學習使用,在個人項目上可以按照需求變化資料集來實作CNN回歸計算

%% 加載資料
%% 資料集包含手寫數字的合成圖像,以及每幅圖像旋轉的對應角度(以角度為機關)。
%% 使用digitTrain4DArrayData和digitTest4DArrayData将訓練和驗證圖像加載為4D數組。
%% 輸出YTrain和YValidation是以角度為機關的旋轉角度。每個訓練和驗證資料集包含5000張圖像。
[XTrain, ~, Ytrain] = digitTrain4DArrayData;
[XValidation, ~, YValidation] = digitTest4DArrayData;
%% 随機顯示20張訓練圖像
numTrainImages = numel(YTrrain);
figure;
idx = randperm(numTrainImages, 20);
for i = 1 : numel(idx)
    subplot(4, 5, i);
    imshow(XTrain(:, :, :, idx(i)))
    drawnow
end           
MATLAB 使用CNN拟合回歸模型預測手寫數字的旋轉角度(卷積神經網絡)

%% 資料歸一化處理

%% 當訓練神經網絡時,確定你的資料在網絡的所有階段都是标準化的通常是有幫助的。

%% 歸一化有助于使用梯度下降來穩定和加速網絡訓練。

%% 如果您的資料規模太小,那麼損失可能會變成NaN,并且在教育訓練期間網絡參數可能會出現分歧。

%% 标準化資料的常用方法包括重新标定資料,使其範圍變為[0,1]或使其均值為0,标準差為1。

%{

你可以标準化以下資料:

1、輸入資料。在将預測器輸入到網絡之前對它們進行規範化。在本例中,輸入圖像已經标準化為[0,1]範圍。

2、層輸出。您可以使用批處理規範化層對每個卷積和完全連接配接層的輸出進行規範化。

3、響應。如果使用批處理規範化層對網絡末端的層輸出進行規範化,則在開始訓練時對網絡的預測進行規範化。

        如果響應的規模與這些預測非常不同,那麼網絡訓練可能無法收斂。

        如果你的回答沒有得到很好的擴充,那麼試着将其标準化,看看網絡教育訓練是否有所改善。

        如果在訓練前對響應進行規範化,則必須轉換訓練網絡的預測,以獲得原始響應的預測。

%}

%% 一般來說,資料不必完全标準化。

%% 但是,如果在本例中訓練網絡來預測100*YTrain或YTrain+500而不是YTrain,那麼損失就變成NaN,

%% 當訓練開始時,網絡參數就會出現分歧。

%% 即使網絡預測aY + b和網絡預測Y之間的唯一差別是重新調整最終完全連接配接層的權重和偏差,這些結果仍然會出現。

%% 如果輸入或響應的分布非常不均勻或傾斜,還可以執行非線性轉換(例如,取對數)

%% 繪制響應分布:在分類問題中,輸出是類機率,類機率總是歸一化的。
figure;
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')           
MATLAB 使用CNN拟合回歸模型預測手寫數字的旋轉角度(卷積神經網絡)

通常,資料不必完全歸一化。但是,如果在此示例中訓練網絡來預測 

100*YTrain

 或 

YTrain+500

 而不是 

YTrain

,則損失将變為 

NaN

,并且網絡參數在訓練開始時會發生偏離。即使預測 aY + b 的網絡與預測 Y 的網絡之間的唯一差異是對最終全連接配接層的權重和偏置的簡單重新縮放,也會出現這些結果。

如果輸入或響應的分布非常不均勻或偏斜,您還可以在訓練網絡之前對資料執行非線性變換(例如,取其對數)。

%% 建立網絡層
%% 第一層定義輸入資料的大小和類型。輸入的圖像大小為28×28×1。建立與訓練圖像大小相同的圖像輸入層。
%% 網絡的中間層定義了網絡的核心架構,大部分計算和學習都在這個架構中進行。
%% 最後一層定義輸出資料的大小和類型。對于回歸問題,全連接配接層必須先于網絡末端的回歸層。
layers = [
    imageInputLayer([28 28 1])
    batchNormalizationLayer
    reluLayer
   
    averagePooling2dLayer(2, 'Stride', 2)
   
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
   
    averagePooling2dLayer(2, 'Stride', 2)
   
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
   
    concolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
   
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];
%% 訓練網絡——Options
%% Train for 30 epochs 學習率0.001 在20個epoch後降低學習率。
%% 通過指定驗證資料和驗證頻率,監控教育訓練過程中的網絡準确性。
%% 根據訓練資料對網絡進行訓練,并在訓練過程中定期對驗證資料進行精度計算。
%% 驗證資料不用于更新網絡權重。打開訓練進度圖,并關閉指令視窗輸出。
miniBatchSize = 128;
validationFrequency = floor(numel(YTrain) / miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 30, ...
    'InitialLearnRate', 1e-3, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 20, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XValidation, YValidation}, ...
    'ValidationFrequency', validationFrequency, ...
    'Plots', 'training-progress', ...
    'Verbose', false);
net = trainNetwork(XTrain, YTrain, layer, options)           

使用 

trainNetwork

 建立網絡。如果存在相容的 GPU,此指令會使用 GPU。否則,

trainNetwork

 将使用 CPU。在 GPU 上進行訓練需要具有 3.0 或更高計算能力的支援 CUDA® 的 NVIDIA® GPU。

檢查 

net

 的 

Layers

 屬性中包含的網絡架構的詳細資訊。

net.Layers           

基于驗證資料評估準确度來測試網絡性能。使用 

predict

 預測驗證圖像的旋轉角度。

YPredicted = predict(net,XValidation);           

評估性能

通過計算以下值來評估模型性能:

  1. 在可接受誤差界限内的預測值的百分比
  2. 預測旋轉角度和實際旋轉角度的均方根誤差 (RMSE)

計算預測旋轉角度和實際旋轉角度之間的預測誤差。

predictionError = YValidation - YPredicted;           

計算在實際角度的可接受誤差界限内的預測值的數量。将門檻值設定為 10 度。計算此門檻值範圍内的預測值的百分比。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages           

使用均方根誤差 (RMSE) 來衡量預測旋轉角度和實際旋轉角度之間的差異。

squares = predictionError.^2;
rmse = sqrt(mean(squares))           

顯示每個數字類的殘差箱線圖

boxplot

 函數需要一個矩陣,其中各個列對應于各個數字類的殘差。

驗證資料按數字類 0-9 對圖像進行分組,每組包含 500 個樣本。使用 

reshape

 按數字類對殘差進行分組。

residualMatrix = reshape(predictionError,500,10);           

residualMatrix

 的每列對應于每個數字的殘差。使用 

boxplot

 (Statistics and Machine Learning Toolbox) 為每個數字建立殘差箱線圖。

figure
boxplot(residualMatrix,...
    'Labels',{'0','1','2','3','4','5','6','7','8','9'})
xlabel('Digit Class')
ylabel('Degrees Error')
title('Residuals')           

準确度最高的數字類具有接近于零的均值和很小的方差。

您可以使用 Image Processing Toolbox 中的函數來擺正數字并将它們顯示在一起。使用 

imrotate

 (Image Processing Toolbox) 根據預測的旋轉角度旋轉 49 個樣本數字。

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end           

顯示原始數字以及校正旋轉後的數字。您可以使用 

montage

 (Image Processing Toolbox) 将數字顯示在同一個圖像上。

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')           
MATLAB 使用CNN拟合回歸模型預測手寫數字的旋轉角度(卷積神經網絡)

繼續閱讀