天天看點

PLA算法---C++

最近在看PLA算法,以下是覺得寫得比較好的部落格,通俗易通。

Coursera上台大老師林軒田的機器學習基石這門課,個人覺得講的非常好,非常适合入門。以下是關于這門課的一些相關童鞋的部落格,總結得特别好。

1.http://wizmann.tk/ml-foundations-pla.html 

這篇部落格用Python語言描述了PLA算法的過程,還有PLA的改進算法Pocket。

2.http://blog.csdn.net/u013455341/article/details/46747343

假設資料集在圖上的反映的效果如下圖所示,紅色和白色的圈圈分為2類

PLA算法---C++
#include<iostream>  
#include<vector> 
#include<cstdlib>

using namespace std;  
//以二維空間為例,x1 x2為屬性,x0是假設機器從原點處開始,然後再在被測試的資料集裡找一個數開始訓練 
struct Item{  
    int x0;  
    double x1,x2;  
    int label;  
};  
//權重結構體,w1 w2為屬性x1 x2的權重,初始值全設為0  
struct Weight{  
    double w0,w1,w2;//  
}Wit0={0,0,0};  
  
//符号函數,根據向量内積和的正負和資料集本身的标簽進行比較,如果不一樣,則需要調整權重  
int sign(double x)
{  
    if(x>0)  
        return 1;  
    else if(x<0)  
        return -1;  
    else return 0;  
}  
//兩個向量的内積  
double DotPro(Item item,Weight wight)
{  
    return item.x0*wight.w0+item.x1*wight.w1+item.x2*wight.w2;  
}  
//更新權重  
Weight UpdateWeight(Item item,Weight weight)
{  
    Weight newWeight;  
    newWeight.w0=weight.w0+item.x0*item.label;  
    newWeight.w1=weight.w1+item.x1*item.label;  
    newWeight.w2=weight.w2+item.x2*item.label;  
    return newWeight;  
}  
int main()
{   
    vector<Item> ivec;  
    Item data;  
    cout<<"Please input x1,x2,label;"<<endl;  
    while(cin>>data.x1>>data.x2>>data.label)
	{   
        data.x0=0;  
        ivec.push_back(data); 
		//cout<<"Please input x1,x2,label:"<<endl;
    }  
	//cin.clear();
    Weight wit=Wit0;  
    for(vector<Item>::iterator iter=ivec.begin();iter!=ivec.end();++iter)
	{  
        if((*iter).label!=sign(DotPro(*iter,wit)))
		{  
            wit=UpdateWeight(*iter,wit);  
            iter=ivec.begin();//  
        }  
    }  
    cout<<wit.w0<<" "<<wit.w1<<" "<<wit.w2<<" "<<endl;
	system("pause");
	
}  
           

測試資料集如下

PLA算法---C++

那麼PLA算法修正多少次才能得出左右的線性分類的解呢?還是用上面的測試資料集,用以下代碼實作

#include<fstream>
#include<iostream>
#include<vector>
using namespace std;

#define DEMENSION 3

double weight[DEMENSION];//權重值
int step = 0;//修改次數
int n = 0;//訓練樣本數
char *file = "training_data3.txt";//讀取檔案名

//存儲訓練樣本,input為x and y,output為label
struct record{
    double input[DEMENSION];
    int output;
};

//把記錄存在向量裡而不是存在結構體數組内,這樣可以根據實際一項項添加
vector<record> trainingSet;

//将資料讀入訓練樣本向量中
void getData(ifstream &datafile)
{
    while(!datafile.eof())
    {
        record curRecord;
        curRecord.input[0] = 1;
        int i;
        for(i = 1; i < DEMENSION; i++){
            datafile>>curRecord.input[i];
        }
        datafile>>curRecord.output;
        trainingSet.push_back(curRecord);
    }
    datafile.close();
    n = trainingSet.size(); 
}

//計算sign值
int sign(double x){
    if(x <= 0)return -1;
    else return 1;
}

//兩向量相加(實際為數組相加),将結果儲存在第一個數組内,用于計算w(i+1)=w(i)+y*x
void add(double *v1,double *v2,int demension){
    int i;
    for(i = 0;i < demension; i++)v1[i] += v2[i];
}

//計算兩數值相乘值,用于判斷w*x是否小于0,若小于0要執行修正算法
double multiply(double *v1,double *v2,int demension){
    double temp = 0.0;
    int i;
    for(i = 0; i < demension; i++)temp += v1[i] * v2[i];
    return temp;
}

//計算實數num與向量乘積放在result中,用于計算y*x
void multiply(double *result,double *v,int demension,int num){
    int i;
    for(i = 0; i < demension; i++)result[i] = num * v[i];
}

void PLA()
{
    int correctNum = 0;//目前連續正确樣本數,當等于n則表明輪完一圈,則表示全部正确,算法結束
    int index = 0;//目前正在計算第幾個樣本
    bool isFinished = 0;//算法是否全部完成的表示,=1表示算法結束
    while(!isFinished){
        if(trainingSet[index].output == sign(multiply(weight,trainingSet[index].input,DEMENSION)))correctNum++;//目前樣本無錯,連續正确樣本數+1
        else{//出錯,執行修正算法
            double temp[DEMENSION];
            multiply(temp,trainingSet[index].input,DEMENSION,trainingSet[index].output);//計算y*x
            add(weight,temp,DEMENSION);//計算w(i+1)=w(i)+y*x
            step++;//進行一次修正,修正次數+1
            correctNum = 0;//由于出錯了,連續正确樣本數歸0
            cout<<"step"<<step<<":"<<endl<<"index="<<index<<" is wrong"<<endl;
        }
        if(index == n-1)index = 0;
        else index++;
        if(correctNum == n)isFinished = 1;
    }
    cout<<"total step:"<<step<<endl;
}

void main()
{
    ifstream dataFile(file);
    if(dataFile.is_open()){
        getData(dataFile);
    }
    else{
        cout<<"出錯,檔案打開失敗!"<<endl;
        exit(1);
    }

    int i;
    for(i = 0; i < DEMENSION; i++)weight[i] = 0.0;
    PLA();
}
           

得到的結果如下:

PLA算法---C++

這裡還有一個4維的測試資料,大家可以拿過去試試~,注意程式中的DEMENSION要改成5哦,讀取檔案中的資料,靠這個來分組的哦。

從https://d396qusza40orc.cloudfront.net/ntumlone%2Fhw1%2Fhw1_15_train.dat 下為訓練資料

答案是45哦~

繼續閱讀