天天看點

決策樹C4.5分類算法的C++實作

一、前言

       當年實習公司布置了一個任務讓寫一個決策樹,以前并未接觸資料挖掘的東西,但作為一個資料挖掘最基本的知識點,還是應該有所了解的。

  程式的源碼可以點選這裡進行下載下傳,下面簡要介紹一下決策樹以及相關算法概念。

  決策樹是一個預測模型;他代表的是對象屬性與對象值之間的一種映射關系。樹中每個節點表示某個對象,而每個分叉路徑則代表的某個可能的屬性值,而每個葉結點則對應從根節點到該葉節點所經曆的路徑所表示的對象的值。決策樹僅有單一輸出,若欲有複數輸出,可以建立獨立的決策樹以處理不同輸出。 資料挖掘中決策樹是一種經常要用到的技術,可以用于分析資料,同樣也可以用來作預測(就像上面的銀行官員用他來預測貸款風險)。從資料産生決策樹的機器學習技術叫做決策樹學習, 通俗說就是決策樹。(來自維基百科)

  1986年Quinlan提出了著名的ID3算法。在ID3算法的基礎上,1993年Quinlan又提出了C4.5算法。為了适應處理大規模資料集的需要,後來又提出了若幹改進的算法,其中SLIQ (super-vised learning in quest)和SPRINT (scalable parallelizableinduction of decision trees)是比較有代表性的兩個算法,此處暫且略過。

  本文實作了C4.5的算法,在ID3的基礎上計算資訊增益,進而更加準确的反應資訊量。其實通俗的說就是建構一棵權重的最短路徑Haffman樹,讓權值最大的節點為父節點。

二、基本概念

  下面簡要介紹一下ID3算法:

  ID3算法的核心是:在決策樹各級結點上選擇屬性時,用資訊增益(information gain)作為屬性的選擇标準,以使得在每一個非葉結點進行測試時,能獲得關于被測試記錄最大的類别資訊。

  其具體方法是:檢測所有的屬性,選擇資訊增益最大的屬性産生決策樹結點,由該屬性的不同取值建立分支,再對各分支的子集遞歸調用該方法建立決策樹結點的分支,直到所有子集僅包含同一類别的資料為止。最後得到一棵決策樹,它可以用來對新的樣本進行分類。

  某屬性的資訊增益按下列方法計算:

決策樹C4.5分類算法的C++實作

      資訊熵是香農提出的,用于描述資訊不純度(不穩定性),其計算公式是Info(D)。

  其中:Pi為子集合中不同性(而二進制分類即正樣例和負樣例)的樣例的比例;j是屬性A中的索引,D是集合樣本,Dj是D中屬性A上值等于j的樣本集合。

      這樣資訊收益可以定義為樣本按照某屬性劃分時造成熵減少的期望,可以區分訓練樣本中正負樣本的能力。資訊增益定義為結點與其子結點的資訊熵之差,公式為Gain(A)。

  ID3算法的優點是:算法的理論清晰,方法簡單,學習能力較強。其缺點是:隻對比較小的資料集有效,且對噪聲比較敏感,當訓練資料集加大時,決策樹可能會随之改變。

  C4.5算法繼承了ID3算法的優點,并在以下幾方面對ID3算法進行了改進:

  1) 用資訊增益率來選擇屬性,克服了用資訊增益選擇屬性時偏向選擇取值多的屬性的不足,公式為GainRatio(A);

  2) 在樹構造過程中進行剪枝;

  3) 能夠完成對連續屬性的離散化處理;

  4) 能夠對不完整資料進行處理。

  C4.5算法與其它分類算法如統計方法、神經網絡等比較起來有如下優點:産生的分類規則易于了解,準确率較高。其缺點是:在構造樹的過程中,需要對資料集進行多次的順序掃描和排序,因而導緻算法的低效。此外,C4.5隻适合于能夠駐留于記憶體的資料集,當訓練集大得無法在記憶體容納時程式無法運作。

決策樹C4.5分類算法的C++實作

三、資料集

實作的C4.5資料集合如下:

決策樹C4.5分類算法的C++實作

它記錄了再不同的天氣狀況下,是否出去覓食的資料。

四、程式代碼

  程式引入狀态樹作為統計和計算屬性的資料結構,它記錄了每次計算後,各個屬性的統計資料,其定義如下:

[cpp]  view plain copy print ?

決策樹C4.5分類算法的C++實作
決策樹C4.5分類算法的C++實作
  1. struct attrItem  
  2. {  
  3.    std::vector<int>  itemNum;  //itemNum[0] = itemLine.size()  
  4.                                //itemNum[1] = decision num  
  5.    set<int>          itemLine;  
  6. };  
  7. struct attributes  
  8. {  
  9.    string attriName;  
  10.    vector<double> statResult;  
  11.    map<string, attrItem*> attriItem;  
  12. };   
  13. vector<attributes*> statTree;  

決策樹節點資料結構如下:

[cpp]  view plain copy print ?

決策樹C4.5分類算法的C++實作
決策樹C4.5分類算法的C++實作
  1. struct TreeNode   
  2. {  
  3.     std::string               m_sAttribute;  
  4.     int                       m_iDeciNum;  
  5.     int                       m_iUnDecinum;  
  6.     std::vector<TreeNode*>    m_vChildren;      
  7. };  

程式源碼如下所示(程式中有詳細注解):

[cpp]  view plain copy print ?

決策樹C4.5分類算法的C++實作
決策樹C4.5分類算法的C++實作
  1. #include "DecisionTree.h"  
  2. int main(int argc, char* argv[]){  
  3.     string filename = "source.txt";  
  4.     DecisionTree dt ;  
  5.     int attr_node = 0;  
  6.     TreeNode* treeHead = nullptr;  
  7.     set<int> readLineNum;  
  8.     vector<int> readClumNum;  
  9.     int deep = 0;  
  10.     if (dt.pretreatment(filename, readLineNum, readClumNum) == 0)  
  11.     {  
  12.         dt.CreatTree(treeHead, dt.getStatTree(), dt.getInfos(), readLineNum, readClumNum, deep);  
  13.     }  
  14.     return 0;  
  15. }  
  16. int DecisionTree::pretreatment(string filename, set<int>& readLineNum, vector<int>& readClumNum)  
  17. {  
  18.     ifstream read(filename.c_str());  
  19.     string itemline = "";  
  20.     getline(read, itemline);  
  21.     istringstream iss(itemline);  
  22.     string attr = "";  
  23.     while(iss >> attr)  
  24.     {  
  25.         attributes* s_attr = new attributes();  
  26.         s_attr->attriName = attr;  
  27.         //初始化屬性名  
  28.         statTree.push_back(s_attr);  
  29.         //初始化屬性映射  
  30.         attr_clum[attr] = attriNum;  
  31.         attriNum++;  
  32.         //初始化可用屬性列  
  33.         readClumNum.push_back(0);  
  34.         s_attr = nullptr;  
  35.     }  
  36.     int i  = 0;  
  37.     //添加具體資料  
  38.     while(true)  
  39.     {  
  40.         getline(read, itemline);  
  41.         if(itemline == "" || itemline.length() <= 1)  
  42.         {  
  43.             break;  
  44.         }  
  45.         vector<string> infoline;  
  46.         istringstream stream(itemline);  
  47.         string item = "";  
  48.         while(stream >> item)  
  49.         {  
  50.             infoline.push_back(item);  
  51.         }  
  52.         infos.push_back(infoline);  
  53.         readLineNum.insert(i);  
  54.         i++;  
  55.     }  
  56.     read.close();  
  57.     return 0;  
  58. }  
  59. int DecisionTree::statister(vector<vector<string>>& infos, vector<attributes*>& statTree,   
  60.                             set<int>& readLine, vector<int>& readClumNum)  
  61. {  
  62.     //yes的總行數  
  63.     int deciNum = 0;  
  64.     //統計每一行  
  65.     set<int>::iterator iter_end = readLine.end();  
  66.     for (set<int>::iterator line_iter = readLine.begin(); line_iter != iter_end; ++line_iter)  
  67.     {  
  68.         bool decisLine = false;  
  69.         if (infos[*line_iter][attriNum - 1] == "yes")  
  70.         {  
  71.             decisLine = true;  
  72.             deciNum++;   
  73.         }  
  74.         //如果該列未被鎖定并且為屬性列,進行統計  
  75.         for (int i = 0; i < attriNum - 1; i++)  
  76.         {  
  77.             if (readClumNum[i] == 0)  
  78.             {  
  79.                 std::string tempitem = infos[*line_iter][i];  
  80.                 auto map_iter = statTree[i]->attriItem.find(tempitem);  
  81.                 //沒有找到  
  82.                 if (map_iter == (statTree[i]->attriItem).end())  
  83.                 {  
  84.                     //建立  
  85.                     attrItem* attritem = new attrItem();  
  86.                     attritem->itemNum.push_back(1);  
  87.                     decisLine ? attritem->itemNum.push_back(1) : attritem->itemNum.push_back(0);  
  88.                     attritem->itemLine.insert(*line_iter);  
  89.                     //建立屬性名->item映射  
  90.                     (statTree[i]->attriItem)[tempitem] = attritem;  
  91.                     attritem = nullptr;  
  92.                 }  
  93.                 else  
  94.                 {  
  95.                     (map_iter->second)->itemNum[0]++;  
  96.                     (map_iter->second)->itemLine.insert(*line_iter);  
  97.                     if(decisLine)  
  98.                     {  
  99.                         (map_iter->second)->itemNum[1]++;  
  100.                     }  
  101.                 }  
  102.             }  
  103.         }  
  104.     }  
  105.     return deciNum;  
  106. }  
  107. void DecisionTree::CreatTree(TreeNode* treeHead, vector<attributes*>& statTree, vector<vector<string>>& infos,   
  108.                              set<int>& readLine, vector<int>& readClumNum, int deep)  
  109. {  
  110.     //有可統計的行  
  111.     if (readLine.size() != 0)  
  112.     {  
  113.         string treeLine = "";  
  114.         for (int i = 0; i < deep; i++)  
  115.         {  
  116.             treeLine += "--";  
  117.         }  
  118.         //清空其他屬性子樹,進行遞歸  
  119.         resetStatTree(statTree, readClumNum);  
  120.         //統計目前readLine中的資料:包括統計哪幾個屬性、哪些行,  
  121.         //并生成statTree(由于公用一個statTree,所有用引用代替),并傳回目的資訊數  
  122.         int deciNum = statister(getInfos(), statTree, readLine, readClumNum);  
  123.         int lineNum = readLine.size();  
  124.         int attr_node = compuDecisiNote(statTree, deciNum, lineNum, readClumNum);//本條複制為局部變量  
  125.         //該列被鎖定  
  126.         readClumNum[attr_node] = 1;  
  127.         //建立樹根  
  128.         TreeNode* treeNote = new TreeNode();  
  129.         treeNote->m_sAttribute = statTree[attr_node]->attriName;  
  130.         treeNote->m_iDeciNum = deciNum;  
  131.         treeNote->m_iUnDecinum = lineNum - deciNum;  
  132.         if (treeHead == nullptr)  
  133.         {  
  134.             treeHead = treeNote; //樹根  
  135.         }  
  136.         else  
  137.         {  
  138.             treeHead->m_vChildren.push_back(treeNote); //子節點  
  139.         }  
  140.         cout << "節點-"<< treeLine << ">" << statTree[attr_node]->attriName    << endl;  
  141.         //從孩子分支進行遞歸  
  142.         for(map<string, attrItem*>::iterator map_iterator = statTree[attr_node]->attriItem.begin();  
  143.             map_iterator != statTree[attr_node]->attriItem.end(); ++map_iterator)  
  144.         {  
  145.             //列印分支  
  146.             int sum = map_iterator->second->itemNum[0];  
  147.             int deci_Num = map_iterator->second->itemNum[1];  
  148.             cout << "分支--"<< treeLine << ">" << map_iterator->first << endl;  
  149.             //遞歸計算、建立  
  150.             if (deci_Num != 0 && sum != deci_Num )  
  151.             {  
  152.                 //計算有效行數  
  153.                 set<int> newReadLineNum = map_iterator->second->itemLine;  
  154.                 //DFS  
  155.                 CreatTree(treeNote, statTree, infos, newReadLineNum, readClumNum, deep + 1);  
  156.             }  
  157.             else  
  158.             {  
  159.                 //建立葉子節點  
  160.                 TreeNode* treeEnd = new TreeNode();  
  161.                 treeEnd->m_sAttribute = statTree[attr_node]->attriName;  
  162.                 treeEnd->m_iDeciNum = deci_Num;  
  163.                 treeEnd->m_iUnDecinum = sum - deci_Num;  
  164.                 treeNote->m_vChildren.push_back(treeEnd);  
  165.                 //列印葉子  
  166.                 if (deci_Num == 0)  
  167.                 {  
  168.                     cout << "葉子---"<< treeLine << ">no" << endl;  
  169.                 }  
  170.                 else  
  171.                 {  
  172.                     cout << "葉子---"<< treeLine << ">yes" << endl;  
  173.                 }  
  174.             }  
  175.         }  
  176.         //還原屬性列可用性  
  177.         readClumNum[attr_node] = 0;  
  178.     }  
  179. }  
  180. int DecisionTree::compuDecisiNote(vector<attributes*>& statTree, int deciNum, int lineNum, vector<int>& readClumNum)  
  181. {  
  182.     double max_temp = 0;  
  183.     int max_attribute = 0;  
  184.     //總的yes行的資訊量  
  185.     double infoD = info_D(deciNum, lineNum);  
  186.     for (int i = 0; i < attriNum - 1; i++)  
  187.     {  
  188.         if (readClumNum[i] == 0)  
  189.         {  
  190.             double splitInfo = 0.0;  
  191.             //info  
  192.             double info_temp = Info_attr(statTree[i]->attriItem, splitInfo, lineNum);  
  193.             statTree[i]->statResult.push_back(info_temp);  
  194.             //gain  
  195.             double gain_temp = infoD - info_temp;  
  196.             statTree[i]->statResult.push_back(gain_temp);  
  197.             //split_info  
  198.             statTree[i]->statResult.push_back(splitInfo);  
  199.             //gain_info  
  200.             double temp = gain_temp / splitInfo;  
  201.             statTree[i]->statResult.push_back(temp);  
  202.             //得到最大值*/  
  203.             if (temp > max_temp)  
  204.             {  
  205.                 max_temp = temp;  
  206.                 max_attribute = i;  
  207.             }  
  208.         }  
  209.     }  
  210.     return max_attribute;  
  211. }  
  212. double DecisionTree::info_D(int deciNum, int sum)  
  213. {  
  214.     double pi = (double)deciNum / (double)sum;  
  215.     double result = 0.0;  
  216.     if (pi == 1.0 || pi == 0.0)  
  217.     {  
  218.         return result;  
  219.     }  
  220.     result = pi * (log(pi) / log((double)2)) + (1 - pi)*(log(1 - pi)/log((double)2));  
  221.     return -result;  
  222. }  
  223. double DecisionTree::Info_attr(map<string, attrItem*>& attriItem, double& splitInfo, int lineNum)  
  224. {  
  225.     double result = 0.0;  
  226.     for (map<string, attrItem*>::iterator item = attriItem.begin();  
  227.          item != attriItem.end();  
  228.          ++item  
  229.         )  
  230.     {  
  231.          double pi = (double)(item->second->itemNum[0]) / (double)lineNum;  
  232.          splitInfo += pi * (log(pi) / log((double)2));  
  233.          double sub_attr = info_D(item->second->itemNum[1], item->second->itemNum[0]);  
  234.          result += pi * sub_attr;  
  235.     }  
  236.     splitInfo = -splitInfo;  
  237.     return result;  
  238. }  
  239. void DecisionTree::resetStatTree(vector<attributes*>& statTree, vector<int>& readClumNum)  
  240. {  
  241.     for (int i = 0; i < readClumNum.size() - 1; i++)  
  242.     {  
  243.         if (readClumNum[i] == 0)  
  244.         {  
  245.             map<string, attrItem*>::iterator it_end = statTree[i]->attriItem.end();  
  246.             for (map<string, attrItem*>::iterator it = statTree[i]->attriItem.begin();  
  247.                 it != it_end; it++)  
  248.             {  
  249.                 delete it->second;  
  250.             }  
  251.             statTree[i]->attriItem.clear();  
  252.             statTree[i]->statResult.clear();  
  253.         }  
  254.     }  
  255. }  

五、結果分析

程式輸出結果為:

決策樹C4.5分類算法的C++實作

以圖形表示為:

決策樹C4.5分類算法的C++實作

六、小結:

  1、在設計程式時,對程式邏輯有時會發生混亂,·後者在紙上仔細畫了些草圖才解決這些問題,畫一個好圖可以有效的幫助你了解程式的流程以及邏輯脈絡,是需求分析時最為關鍵的基本功。

  2、在編寫程式之初,一直在糾結用什麼樣的資料結構,後來經過幾次在程式設計實作推敲,才确定最佳的資料結構,可見資料結構在程式中的重要性。

  3、決策樹的編寫,其實就是理論與實踐的相結合,雖然理論上比較簡單,但是實踐中卻會遇到這樣那樣的問題,而這些問題就是考驗一個程式員對最基本的資料結構、算法的了解和熟練程度,是以,勤學勤練基本功依然是關鍵。

  4、程式的效率還有待提高,歡迎各路高手指正。

http://blog.csdn.net/fy2462/article/details/31762429

繼續閱讀