Programming Exercise 4:Neural Networks Learning
這周總結一下Week5的作業,需要實作三層神經網絡内部參數的訓練。應用的例子同Week3一緻,均為識别手寫數字。總的來說本周作業較前幾次難度略有提升。
首先看一下作業要求(打星号的需要送出):
資料預處理上與上一周類似,此處不再贅述。
sigmoidGradient.m目的是給出在backpropagation(反向傳播)過程中各層的誤差配置設定。
g=sigmoid(z)*(1-sigmoid(z));
randInitializeWeights.m目的是對神經網絡各層的參數進行初始化,注意這裡的參數不能夠全部相同,否則會導緻訓練後參數全部相同,無法達到精度要求。最簡單的方法就是在随機産生一個相同次元矩陣的基礎上進行适當修改。
epsilon_init = 0.12;
W = rand(L_out, 1 + L_in) * 2 * epsilon_init-epsilon_init;
L_in和L_out是該層網絡的輸入和輸對外連結接數,由于要加一個常數項,輸傳入連結接加1。W中的元素在
,這裡
。
nnCostFunction.m這次誤差函數裡面可是相當有料啊!包含前向傳播和後向傳播兩塊内容。參數矩陣第一層25x401,第二層10x26。
前向傳播:
誤差函數比較好處理。
y是訓練樣本的類别向量5000x10,從初值到需要經過2次正向傳播,對每層參數實施正則化。
%calculate h(x)
a1 = [ones(m, 1) X];
z2 = a1 * Theta1';
a2 = sigmoid(z2);
a2 = [ones(m, 1) a2];
z3 = a2 * Theta2';
h = sigmoid(z3);
%calculate yk
yk = zeros(m, num_labels);
for i = 1:m
yk(i, y(i)) = 1; %y is class vector
end
%costFunction
J = (1/m)* sum(sum(((-yk) .* log(h) - (1 - yk) .* log(1 - h))));
%Regularized cost function,Theta1 and Theta2 are the weight matrices
Theta1_new=Theta1(:,2:size(Theta1,2));
Theta2_new=Theta2(:,2:size(Theta2,2));
J=J+lambda/2/m*(Theta1_new(:)'*Theta1_new(:)+Theta2_new(:)'*Theta2_new(:));
反向傳播:目的在于把每次正向傳播的誤差回報給每層的參數并按照梯度給出的最優方向進行調整。
步驟:1.實施正向傳播到第三層時,算出與真值的誤差。
2.利用sigmoid函數的梯度,算出第二層的誤差矩陣。
3.計算各層的梯度(注意要去掉第一列,即常數列)。
表示第l層的梯度。
4.正則化。
%step 1 and 2
for i=1:m
y_new=zeros(1,num_labels);
y_new(1,y(i))=1;
a1=[1;X(i,:)'];
a2=[1;sigmoid(Theta1*a1)];
a3=sigmoid(Theta2*a2);
det3=a3-y_new';
det2=Theta2'*det3.*sigmoidGradient([1;Theta1*a1]);
det2 = det2(2:end);
Theta1_grad=Theta1_grad+det2*a1';
Theta2_grad=Theta2_grad+det3*a2';
end
%step 3 and 4
Theta1_grad(:,1)=Theta1_grad(:,1)/m;
Theta1_grad(:,2:size(Theta1_grad,2))=Theta1_grad(:,2:size(Theta1_grad,2))/m+...
lambda*Theta1(:,2:size(Theta1,2))/m;
Theta2_grad(:,1)=Theta2_grad(:,1)/m;
Theta2_grad(:,2:size(Theta2_grad,2))=Theta2_grad(:,2:size(Theta2_grad,2))/m+...
lambda*Theta2(:,2:size(Theta2,2))/m;
最終精度為95.06%,這與作業要求相近。
總結一下,反向傳播裡面梯度和誤差矩陣那裡還是有一些似懂非懂;程式裡面還可以再簡化一下,以後再貼上來;本身是數學系的,而且自認為高代學得還可以,還是差點被裡面加減項搞暈,還是要再好好琢磨一下裡面的原理。