最近在看PLA算法,以下是覺得寫得比較好的部落格,通俗易通。
Coursera上台大老師林軒田的機器學習基石這門課,個人覺得講的非常好,非常适合入門。以下是關于這門課的一些相關童鞋的部落格,總結得特别好。
1.http://wizmann.tk/ml-foundations-pla.html
這篇部落格用Python語言描述了PLA算法的過程,還有PLA的改進算法Pocket。
2.http://blog.csdn.net/u013455341/article/details/46747343
假設資料集在圖上的反映的效果如下圖所示,紅色和白色的圈圈分為2類
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiIyVGduV2QvwVe0lmdhJ3ZvwFM38CXlZHbvN3cpR2Lc1TPB10QGtWUCpEMJ9CXsxWam9CXwADNvwVZ6l2c052bm9CXUJDT1wkNhVzLcRnbvZ2LcZXUYpVd1kmYr50MZV3YyI2cKJDT29GRjBjUIF2LcRHelR3LcJzLctmch1mclRXY39TN0YDMyMDN4EzMyYDM2EDMy8CX0Vmbu4GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
#include<iostream>
#include<vector>
#include<cstdlib>
using namespace std;
//以二維空間為例,x1 x2為屬性,x0是假設機器從原點處開始,然後再在被測試的資料集裡找一個數開始訓練
struct Item{
int x0;
double x1,x2;
int label;
};
//權重結構體,w1 w2為屬性x1 x2的權重,初始值全設為0
struct Weight{
double w0,w1,w2;//
}Wit0={0,0,0};
//符号函數,根據向量内積和的正負和資料集本身的标簽進行比較,如果不一樣,則需要調整權重
int sign(double x)
{
if(x>0)
return 1;
else if(x<0)
return -1;
else return 0;
}
//兩個向量的内積
double DotPro(Item item,Weight wight)
{
return item.x0*wight.w0+item.x1*wight.w1+item.x2*wight.w2;
}
//更新權重
Weight UpdateWeight(Item item,Weight weight)
{
Weight newWeight;
newWeight.w0=weight.w0+item.x0*item.label;
newWeight.w1=weight.w1+item.x1*item.label;
newWeight.w2=weight.w2+item.x2*item.label;
return newWeight;
}
int main()
{
vector<Item> ivec;
Item data;
cout<<"Please input x1,x2,label;"<<endl;
while(cin>>data.x1>>data.x2>>data.label)
{
data.x0=0;
ivec.push_back(data);
//cout<<"Please input x1,x2,label:"<<endl;
}
//cin.clear();
Weight wit=Wit0;
for(vector<Item>::iterator iter=ivec.begin();iter!=ivec.end();++iter)
{
if((*iter).label!=sign(DotPro(*iter,wit)))
{
wit=UpdateWeight(*iter,wit);
iter=ivec.begin();//
}
}
cout<<wit.w0<<" "<<wit.w1<<" "<<wit.w2<<" "<<endl;
system("pause");
}
測試資料集如下
那麼PLA算法修正多少次才能得出左右的線性分類的解呢?還是用上面的測試資料集,用以下代碼實作
#include<fstream>
#include<iostream>
#include<vector>
using namespace std;
#define DEMENSION 3
double weight[DEMENSION];//權重值
int step = 0;//修改次數
int n = 0;//訓練樣本數
char *file = "training_data3.txt";//讀取檔案名
//存儲訓練樣本,input為x and y,output為label
struct record{
double input[DEMENSION];
int output;
};
//把記錄存在向量裡而不是存在結構體數組内,這樣可以根據實際一項項添加
vector<record> trainingSet;
//将資料讀入訓練樣本向量中
void getData(ifstream &datafile)
{
while(!datafile.eof())
{
record curRecord;
curRecord.input[0] = 1;
int i;
for(i = 1; i < DEMENSION; i++){
datafile>>curRecord.input[i];
}
datafile>>curRecord.output;
trainingSet.push_back(curRecord);
}
datafile.close();
n = trainingSet.size();
}
//計算sign值
int sign(double x){
if(x <= 0)return -1;
else return 1;
}
//兩向量相加(實際為數組相加),将結果儲存在第一個數組内,用于計算w(i+1)=w(i)+y*x
void add(double *v1,double *v2,int demension){
int i;
for(i = 0;i < demension; i++)v1[i] += v2[i];
}
//計算兩數值相乘值,用于判斷w*x是否小于0,若小于0要執行修正算法
double multiply(double *v1,double *v2,int demension){
double temp = 0.0;
int i;
for(i = 0; i < demension; i++)temp += v1[i] * v2[i];
return temp;
}
//計算實數num與向量乘積放在result中,用于計算y*x
void multiply(double *result,double *v,int demension,int num){
int i;
for(i = 0; i < demension; i++)result[i] = num * v[i];
}
void PLA()
{
int correctNum = 0;//目前連續正确樣本數,當等于n則表明輪完一圈,則表示全部正确,算法結束
int index = 0;//目前正在計算第幾個樣本
bool isFinished = 0;//算法是否全部完成的表示,=1表示算法結束
while(!isFinished){
if(trainingSet[index].output == sign(multiply(weight,trainingSet[index].input,DEMENSION)))correctNum++;//目前樣本無錯,連續正确樣本數+1
else{//出錯,執行修正算法
double temp[DEMENSION];
multiply(temp,trainingSet[index].input,DEMENSION,trainingSet[index].output);//計算y*x
add(weight,temp,DEMENSION);//計算w(i+1)=w(i)+y*x
step++;//進行一次修正,修正次數+1
correctNum = 0;//由于出錯了,連續正确樣本數歸0
cout<<"step"<<step<<":"<<endl<<"index="<<index<<" is wrong"<<endl;
}
if(index == n-1)index = 0;
else index++;
if(correctNum == n)isFinished = 1;
}
cout<<"total step:"<<step<<endl;
}
void main()
{
ifstream dataFile(file);
if(dataFile.is_open()){
getData(dataFile);
}
else{
cout<<"出錯,檔案打開失敗!"<<endl;
exit(1);
}
int i;
for(i = 0; i < DEMENSION; i++)weight[i] = 0.0;
PLA();
}
得到的結果如下:
這裡還有一個4維的測試資料,大家可以拿過去試試~,注意程式中的DEMENSION要改成5哦,讀取檔案中的資料,靠這個來分組的哦。
從https://d396qusza40orc.cloudfront.net/ntumlone%2Fhw1%2Fhw1_15_train.dat 下為訓練資料
答案是45哦~