天天看點

決策樹學習 之 ID3 C++STL代碼實作

很久沒寫含這麼多stl的程式了,很故意的用set,map,vector,熟手一下。

也記錄一下吧,雖然寫得比較渣。

三個檔案:

測試資料:data.txt

[plain] view plain copy

  1. D1    Sunny        Hot    High        Weak    No  
  2. D2    Sunny        Hot    High        Strong    No  
  3. D3    Overcast    Hot    High        Weak    Yes  
  4. D4    Rain        Mild    High        Weak    Yes  
  5. D5    Rain        Cool    Normal        Weak    Yes  
  6. D6    Rain        Cool    Normal        Strong    No  
  7. D7    Overcast    Cool    Normal        Strong    Yes  
  8. D8    Sunny        Mild    High        Weak    No  
  9. D9    Sunny        Cool    Normal        Weak    Yes  
  10. D10    Rain        Mild    Normal        Weak    Yes  
  11. D11    Sunny        Mild    Normal        Strong    Yes  
  12. D12    Overcast    Mild    High        Strong    Yes  
  13. D13    Overcast    Hot    Normal        Weak    Yes  
  14. D14    Rain        Mild    High        Strong    No  

程式頭檔案:id3.h

[cpp] view plain copy

  1. #ifndef ID3_H  
  2. #define ID3_H  
  3. #include<fstream>  
  4. #include<iostream>  
  5. #include<vector>  
  6. #include<map>  
  7. #include<set>  
  8. #include<cmath>  
  9. using namespace std;  
  10. const int DataRow=14;  
  11. const int DataColumn=6;  
  12. struct Node  
  13. {  
  14.     double value;//代表此時yes的機率。  
  15.     int attrid;  
  16.     Node * parentNode;  
  17.     vector<Node*> childNode;  
  18. };  
  19. #endif  

程式源檔案id3.cpp

[cpp] view plain copy

  1. #include "id3.h"  
  2. string DataTable[DataRow][DataColumn];  
  3. map<string,int> str2int;  
  4. set<int> S;  
  5. set<int> Attributes;  
  6. string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};  
  7. string attrValue[DataColumn][DataRow]=  
  8. {  
  9.     {},//D1,D2這個屬性不需要  
  10.     {"Sunny","Overcast","Rain"},  
  11.     {"Hot","Mild","Cool"},  
  12.     {"High","Normal"},  
  13.     {"Weak","Strong"},  
  14.     {"No","Yes"}  
  15. };  
  16. int attrCount[DataColumn]={14,3,3,2,2,2};  
  17. double lg2(double n)  
  18. {  
  19.     return log(n)/log(2);  
  20. }  
  21. void Init()  
  22. {  
  23.     ifstream fin("data.txt");  
  24.     for(int i=0;i<14;i++)  
  25.     {  
  26.       for(int j=0;j<6;j++)  
  27.       {  
  28.           fin>>DataTable[i][j];  
  29.       }  
  30.     }  
  31.     fin.close();  
  32.     for(int i=1;i<=5;i++)  
  33.     {  
  34.         str2int[attrName[i]]=i;  
  35.         for(int j=0;j<attrCount[i];j++)  
  36.         {  
  37.             str2int[attrValue[i][j]]=j;  
  38.         }  
  39.     }  
  40.     for(int i=0;i<DataRow;i++)  
  41.       S.insert(i);  
  42.     for(int i=1;i<=4;i++)  
  43.       Attributes.insert(i);  
  44. }  
  45. double Entropy(const set<int> &s)  
  46. {  
  47.     double yes=0,no=0,sum=s.size(),ans=0;  
  48.     for(set<int>::iterator it=s.begin();it!=s.end();it++)  
  49.     {  
  50.         string s=DataTable[*it][str2int["PlayTennis"]];  
  51.         if(s=="Yes")  
  52.           yes++;  
  53.         else  
  54.           no++;  
  55.     }  
  56.     if(no==0||yes==0)  
  57.       return ans=0;  
  58.     ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);  
  59.     return ans;  
  60. }  
  61. double Gain(const set<int> & example,int attrid)  
  62. {  
  63.     int attrcount=attrCount[attrid];  
  64.     double ans=Entropy(example);  
  65.     double sum=example.size();  
  66.     set<int> * pset=new set<int>[attrcount];  
  67.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  68.     {  
  69.         pset[str2int[DataTable[*it][attrid]]].insert(*it);  
  70.     }  
  71.     for(int i=0;i<attrcount;i++)  
  72.     {  
  73.         ans-=pset[i].size()/sum*Entropy(pset[i]);  
  74.     }  
  75.     return ans;  
  76. }  
  77. int FindBestAttribute(const set<int> & example,const set<int> & attr)  
  78. {  
  79.     double mx=0;  
  80.     int k=-1;  
  81.     for(set<int>::iterator i=attr.begin();i!=attr.end();i++)  
  82.     {  
  83.         double ret=Gain(example,*i);  
  84.         if(ret>mx)  
  85.         {  
  86.             mx=ret;  
  87.             k=*i;  
  88.         }  
  89.     }  
  90.     if(k==-1)  
  91.       cout<<"FindBestAttribute error!"<<endl;  
  92.     return k;  
  93. }  
  94. Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)  
  95. {  
  96.     Node *now=new Node;//建立樹節點。  
  97.     now->parentNode=parent;  
  98.     if(attributes.empty())//如果此時屬性清單已用完,即為空,則傳回。  
  99.       return now;  
  100.     int yes=0,no=0,sum=example.size();  
  101.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  102.     {  
  103.         string s=DataTable[*it][str2int["PlayTennis"]];  
  104.         if(s=="Yes")  
  105.           yes++;  
  106.         else  
  107.           no++;  
  108.     }  
  109.     if(yes==sum||yes==0)  
  110.     {  
  111.         now->value=yes/sum;  
  112.         return now;  
  113.     }  
  114.     int bestattrid=FindBestAttribute(example,attributes);  
  115.     now->attrid=bestattrid;  
  116.     attributes.erase(attributes.find(bestattrid));  
  117.     vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);  
  118.     for(set<int>::iterator i=example.begin();i!=example.end();i++)  
  119.     {  
  120.         int id=str2int[DataTable[*i][bestattrid]];  
  121.         child[id].insert(*i);  
  122.     }  
  123.     for(int i=0;i<child.size();i++)  
  124.     {  
  125.         Node * ret=Id3_solution(child[i],attributes,now);  
  126.         now->childNode.push_back(ret);  
  127.     }  
  128.     return now;  
  129. }  
  130. int main()  
  131. {  
  132.     Init();  
  133.     Node * Root=Id3_solution(S,Attributes,NULL);  
  134.     return 0;  

繼續閱讀