天天看點

C4.5決策樹(Java實作)

說明

此前已經上傳了ID3決策樹的Java實作,C4.5整體架構與之相差不大。

可參考:http://blog.csdn.net/xiaohukun/article/details/78041676

此次将結點的實作由Dom4J改為自定義類實作,更加自由和輕便。

代碼已打包并上傳

代碼

資料仍采用ARFF格式

train.arff

@relation weather.symbolic 
@attribute outlook {sunny,overcast,rainy} 
@attribute temperature {hot,mild,cool} 
@attribute humidity {high,normal} 
@attribute windy {TRUE,FALSE} 
@attribute play {yes,no} 

@data 
sunny,hot,high,FALSE,no 
sunny,hot,high,TRUE,no 
overcast,hot,high,FALSE,yes 
rainy,mild,high,FALSE,yes 
rainy,cool,normal,FALSE,yes 
rainy,cool,normal,TRUE,no 
overcast,cool,normal,TRUE,yes 
sunny,mild,high,FALSE,no 
sunny,cool,normal,FALSE,yes 
rainy,mild,normal,FALSE,yes 
sunny,mild,normal,TRUE,yes 
overcast,mild,high,TRUE,yes 
overcast,hot,normal,FALSE,yes 
rainy,mild,high,TRUE,no
           

C4.5類(主類)

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.io.FileOutputStream;
import java.io.BufferedOutputStream;
import java.lang.Math.*;



public class DecisionTree {

    private ArrayList<String> train_AttributeName = new ArrayList<String>(); // 存儲訓練集屬性的名稱
    private ArrayList<ArrayList<String>> train_attributeValue = new ArrayList<ArrayList<String>>(); // 存儲訓練集每個屬性的取值
    private ArrayList<String[]> trainData = new ArrayList<String[]>(); // 訓練集資料 ,即arff檔案中的data字元串

    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
    //正則表達,其中*? 表示重複任意次,但盡可能少重複,防止比對到更後面的"}"符号

    private int decatt; // 決策變量在屬性集中的索引(即類标所在列)
    private InfoGain infoGain;
    private TreeNode root;


    public void train(String data_path, String targetAttr){
        //模型初始化操作
        read_trainARFF(new File(data_path));
        //printData();
        setDec(targetAttr);
        infoGain=new InfoGain(trainData, decatt);

        //拼裝行與列
        LinkedList<Integer> ll=new LinkedList<Integer>(); //LinkList用于增删比ArrayList有優勢
        for(int i = ; i< train_AttributeName.size(); i++){
            if(i!=decatt) ll.add(i);  //防止類别變量不在最後一列發生錯誤
        }
        ArrayList<Integer> al=new ArrayList<Integer>();
        for(int i=;i<trainData.size();i++){
            al.add(i);
        }

        //建構決策樹
        root = buildDT("root", "null", al, ll);
        //剪枝
        cutBranch(root);
    }

    /**
     * 建構決策樹
     * @param fatherName 節點名稱
     * @param fatherValue 節點值
     * @param subset 資料行子集
     * @param subset 資料列子集
     * @return 傳回根節點
     */
    public TreeNode buildDT(String fatherName, String fatherValue, ArrayList<Integer> subset,LinkedList<Integer> selatt){
        TreeNode node=new TreeNode();
        Map<String,Integer> targetNum = infoGain.get_AttributeNum(subset,decatt);//計算類-頻率
        String targetValue=infoGain.get_targetValue(targetNum);//判定分類
        node.setTargetNum(targetNum);
        node.setAttributeName(fatherName);
        node.setAttributeValue(fatherValue);
        node.setTargetValue(targetValue);

        //終止條件為類标單一/樹深度達到特征長度(還有可能是資訊增益率不存在)
        if (infoGain.isPure(targetNum) | selatt.isEmpty() ) {
            node.setNodeType("leafNode");
            return node;
        }
        int maxIndex = infoGain.getGainRatioMax(subset,selatt);
        selatt.remove(new Integer(maxIndex));  //這樣可以remove object
        String childName = train_AttributeName.get(maxIndex);

        Map<String, ArrayList<Integer>> childSubset = infoGain.get_AttributeSubset(subset, maxIndex);
        ArrayList<TreeNode> childNode = new ArrayList<TreeNode>();
        for (String childValue : childSubset.keySet()){
            TreeNode child = buildDT(childName, childValue, childSubset.get(childValue), selatt);
            child.setFatherTreeNode(node);  //順序很重要:回溯
            childNode.add(child);
        }
        node.setChildTreeNode(childNode);
        return  node;
    }

    /**
     * 剪枝函數
     * @param node 判斷結點
     * @return 剪枝之後的葉子結點集
     */
    public ArrayList<int[]> cutBranch(TreeNode node){
        ArrayList<int[]> resultNum = new ArrayList<int[]>();
        if (node.getNodeType() =="leafNode"){
            int[] tempNum = get_leafNum(node);
            resultNum.add(tempNum);
            return resultNum;
        }else{
            int sumNum = ;
            double oldRatio = ;
            for (TreeNode child : node.getChildTreeNode()){
                for(int[] leafNum : cutBranch(child)){
                    resultNum.add(leafNum);
                    oldRatio +=  + leafNum[];
                    sumNum += leafNum[];
                }
            }
            double oldNum =oldRatio;
            oldRatio /= sumNum;
            double sd = Math.sqrt(sumNum*oldRatio*(-oldRatio));
            int temLeaf[] = get_leafNum(node);
            double newNum = temLeaf[] + ;
            if(newNum < oldNum + sd){//符合剪枝條件,剪枝并傳回本身
                node.setChildTreeNode(null);
                node.setNodeType("leafNode");
                resultNum.clear();
                resultNum.add(temLeaf);
            }//不符合剪枝條件,傳回葉子結點
            return resultNum;
        }
    }

    //獲得葉子結點的數目
    public int[] get_leafNum(TreeNode node){
        int[] resultNum= new int[];
        Map<String,Integer> targetNum = node.getTargetNum();
        int minNum = Integer.MAX_VALUE;
        int sumNum = ;
        for(int num : targetNum.values()){
            minNum = Integer.min(minNum, num);
            sumNum += num;
        }
        if (targetNum.size() == ) minNum = ;
        resultNum[] = minNum;
        resultNum[] = sumNum;
        return  resultNum;
    }

    /**
     * 讀取arff檔案,給attribute、attributevalue、data指派
     * @param file  傳入的檔案
     */
    public void read_trainARFF(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    train_AttributeName.add(matcher.group().trim()); //擷取第一個括号裡的内容
                    //涉及取值,盡量加.trim(),後面也可以看到,即使是換行符也可能會造成字元串不相等
                    String[] values = matcher.group().split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    train_attributeValue.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if(line=="")
                            continue;
                        String[] row = line.split(",");
                        trainData.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 列印Data
     */
    public void printData(){
        System.out.println("目前的ATTR為");
        for(String attr : train_AttributeName){
            System.out.print(attr+" ");
        }
        System.out.println();
        System.out.println("---------------------------------");
        System.out.println("目前的DATA為");
        for(String[] row: trainData){
            for (String value : row){
                System.out.print(value+" ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
    }

    //将決策樹存儲到xml檔案中
    public void write_DecisionTree(String filename) {
        try {
            File file = new File(filename);
            if (!file.exists())
                file.createNewFile();
            FileOutputStream fs = new FileOutputStream(filename);
            BufferedOutputStream bos = new BufferedOutputStream(fs);
            write_Node(bos, root, "");
            bos.flush();
            bos.close();
            fs.close();
        }catch (IOException e){
            e.printStackTrace();
        }
    }

    private void write_Node(BufferedOutputStream bos, TreeNode node, String block){
        String outputWords1 = block + "<" + node.getAttributeName()+ " value=\"" + node.getAttributeValue() + "\"";
        String outputWords2;
        Map<String, Integer> targetNum = node.getTargetNum();
        for (String value : targetNum.keySet()){
            outputWords1 += " " + value + ":" + targetNum.get(value);
        }
        outputWords1 += ">";
        if(node.getNodeType()=="leafNode"){
            outputWords1 += node.getTargetValue();
            outputWords2 = "</" + node.getAttributeName() + ">" + "\n";
        }else{
            outputWords1 += "\n";
            outputWords2 = block + "</" + node.getAttributeName() + ">" + "\n";
        }

        try {
            bos.write(outputWords1.getBytes());
        }catch (IOException e){
            e.printStackTrace();
        }
        ArrayList<TreeNode> childNode=node.getChildTreeNode();
        if (childNode !=null){
            for (TreeNode child : childNode){
                write_Node(bos, child, block+"  ");
            }
        }

        try {
            bos.write(outputWords2.getBytes());
        }catch (IOException e){
            System.out.println(e.getMessage());
        }
    }

    //設定決策變量
    public void setDec(int n) {
        if (n <  || n >= train_AttributeName.size()) {
            System.err.println("決策變量指定錯誤。");
            System.exit();
        }
        decatt = n;
    }
    public void setDec(String targetAttr) {
        int n = train_AttributeName.indexOf(targetAttr);
        setDec(n);
    }



    public static void main(String[] args) {
        DecisionTree dt=new DecisionTree();
        dt.train("files/train.arff", "play");
        dt.write_DecisionTree("files/Tree.xml");
    }

}
           

節點類

import java.util.ArrayList;
import java.util.Map;

/**
 * A node of the decision tree.
 *
 * attributeName/attributeValue identify the test and the incoming branch,
 * targetNum holds the class frequencies of the rows reaching this node, and
 * targetValue is the majority class. nodeType is "leafNode" for leaves.
 */
public class TreeNode {

    private String nodeType;                    // "leafNode" for leaves, null/other for internal nodes
    private String attributeName;               // attribute tested at this node
    private String attributeValue;              // value on the branch from the parent
    private ArrayList<TreeNode> childTreeNode;  // children; null for leaves
    private TreeNode fatherTreeNode;            // parent; null for the root
    private Map<String, Integer> targetNum;     // class label -> frequency in this node's subset
    private String targetValue;                 // majority class of this node's subset

    public TreeNode() {
    }

    public String getNodeType() {
        return nodeType;
    }

    public void setNodeType(String nodeType) {
        this.nodeType = nodeType;
    }

    public String getAttributeName() {
        return attributeName;
    }

    public void setAttributeName(String attributeName) {
        this.attributeName = attributeName;
    }

    public String getAttributeValue() {
        return attributeValue;
    }

    public void setAttributeValue(String attributeValue) {
        this.attributeValue = attributeValue;
    }

    public ArrayList<TreeNode> getChildTreeNode() {
        return childTreeNode;
    }

    public void setChildTreeNode(ArrayList<TreeNode> childTreeNode) {
        this.childTreeNode = childTreeNode;
    }

    public TreeNode getFatherTreeNode() {
        return fatherTreeNode;
    }

    public void setFatherTreeNode(TreeNode fatherTreeNode) {
        this.fatherTreeNode = fatherTreeNode;
    }

    public Map<String, Integer> getTargetNum() {
        return targetNum;
    }

    public void setTargetNum(Map<String, Integer> targetNum) {
        this.targetNum = targetNum;
    }

    public String getTargetValue() {
        return targetValue;
    }

    public void setTargetValue(String targetValue) {
        this.targetValue = targetValue;
    }
}
           

資訊熵相關類

import java.util.*;


/**
 * Information-gain utilities for C4.5: base-2 entropy, gain-ratio attribute
 * selection, subset partitioning and class-frequency counting.
 */
public class InfoGain {
    private ArrayList<String[]> trainData; // training rows (shared with the caller, not copied)
    private int decatt;                    // column index of the class attribute

    public InfoGain(ArrayList<String[]> trainData, int decatt) {
        this.trainData = trainData;
        this.decatt = decatt;
    }

    /**
     * Base-2 entropy of a value -> count map, computed as
     * H = log2(sum) - (1/sum) * Σ n_i * log2(n_i).
     *
     * @param attributeNum value -> frequency map
     * @return entropy in bits; 0 for an empty map
     */
    public double getEntropy(Map<String, Integer> attributeNum) {
        double entropy = 0;
        int sum = 0;
        for (int num : attributeNum.values()) {
            sum += num;
            // Double.MIN_VALUE guards the log argument against 0.
            entropy += (-1) * num * Math.log(num + Double.MIN_VALUE) / Math.log(2);
        }
        if (sum == 0) return 0; // empty subset: avoid 0/0 = NaN
        entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    /**
     * Entropy of the given attribute column restricted to a row subset.
     */
    public double getEntropy(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, Integer> attributeNum = get_AttributeNum(subset, attributeIndex);
        return getEntropy(attributeNum);
    }

    /**
     * Picks the attribute in selatt with the maximum gain ratio.
     * Falls back to decatt when no attribute has a positive gain ratio
     * (a negative/zero gain ratio would ideally stop the split altogether).
     *
     * @param subset row indices under consideration
     * @param selatt candidate attribute columns
     * @return column index of the best attribute
     */
    public int getGainRatioMax(ArrayList<Integer> subset, LinkedList<Integer> selatt) {
        // Entropy before the split.
        Map<String, Integer> old_TargetNum = get_AttributeNum(subset, decatt);
        double oldEntropy = getEntropy(old_TargetNum);
        double maxGainRatio = 0;
        int maxIndex = decatt;

        for (int attributeIndex : selatt) {
            Map<String, ArrayList<Integer>> attributeSubset = get_AttributeSubset(subset, attributeIndex);

            // Weighted entropy after splitting on attributeIndex.
            int sum = 0;
            double newEntropy = 0;
            for (ArrayList<Integer> tempSubset : attributeSubset.values()) {
                int num = tempSubset.size();
                sum += num;
                newEntropy += num * getEntropy(tempSubset, decatt);
            }
            newEntropy /= sum;

            // Gain ratio = information gain / split information.
            double splitInfo = getEntropy(subset, attributeIndex);
            if (splitInfo == 0) continue; // single-valued attribute: ratio undefined, skip
            double tempGainRatio = (oldEntropy - newEntropy) / splitInfo;

            if (tempGainRatio > maxGainRatio) {
                maxGainRatio = tempGainRatio;
                maxIndex = attributeIndex;
            }
        }
        return maxIndex;
    }

    /**
     * Whether the class distribution is pure (at most one distinct class).
     *
     * @param targetNum class -> frequency map
     * @return true when zero or one class is present
     */
    public boolean isPure(Map<String, Integer> targetNum) {
        return targetNum.size() <= 1;
    }

    /**
     * Value -> frequency map of one attribute column over a row subset.
     *
     * @param subset         row indices
     * @param attributeIndex attribute column
     * @return frequency map
     */
    public Map<String, Integer> get_AttributeNum(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, Integer> attributeNum = new HashMap<String, Integer>();
        for (int subsetIndex : subset) {
            String value = trainData.get(subsetIndex)[attributeIndex];
            attributeNum.merge(value, 1, Integer::sum); // count occurrences
        }
        return attributeNum;
    }

    /**
     * Partitions a row subset by the values of one attribute column.
     *
     * @param subset         row indices to partition
     * @param attributeIndex attribute column
     * @return attribute value -> row-index sublist
     */
    public Map<String, ArrayList<Integer>> get_AttributeSubset(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, ArrayList<Integer>> attributeSubset = new HashMap<String, ArrayList<Integer>>();
        for (int subsetIndex : subset) {
            String value = trainData.get(subsetIndex)[attributeIndex];
            attributeSubset.computeIfAbsent(value, k -> new ArrayList<Integer>()).add(subsetIndex);
        }
        return attributeSubset;
    }

    /**
     * Majority class of a class -> frequency map.
     *
     * @param targetNum class -> frequency map
     * @return the most frequent class; "" for an empty map
     */
    public String get_targetValue(Map<String, Integer> targetNum) {
        int maxNum = 0;
        String targetValue = "";
        for (Map.Entry<String, Integer> entry : targetNum.entrySet()) {
            if (entry.getValue() > maxNum) {
                maxNum = entry.getValue();
                targetValue = entry.getKey();
            }
        }
        return targetValue;
    }
}
           

感受

決策樹屬于比較基本的分類算法,但是在編寫代碼的過程中,我對于疊代的運用和代碼實作有了更進一步地認識。

在C4.5中有兩塊工作比較重要和複雜,其一,自然是生成決策樹;其二,便是實作剪枝。

這二者都是通過遞迴來實作的,并且都經歷了 top-down(自上而下)和 bottom-up(自下而上)兩個方向,隻不過前者是在自上而下的過程中完成主要操作,回溯隻是用以獲得傳回的結點;而後者的自上而下隻是為了找到各個葉子結點,真正的剪枝工作是在回溯的過程實作的。

問題

此次的代碼中并沒有實作對連續特征的處理以及缺失值的處理。

後者根據具體的情況變化較大,而前者根據目前提供的函數應該可以比較友善的實作,也就不再浪費時間了,如果有親希望保證完整性,可以自行補充。

繼續閱讀