天天看點

weka之Bagging的源碼分析及相關知識點

 Bagging的源碼分析及相關知識點

1、Bagging的構造函數:

        public Bagging() {

             m_Classifier = new weka.classifiers.trees.REPTree();

        }

2、Bagging的繼承關系及父類的主要屬性和方法

(以下為逐級單重繼承及抽象父類的一些重要屬性和方法)

   Bagging

(protected intm_BagSizePercent = 100;

    protected boolean m_CalcOutOfBag = false;

    protected double m_OutOfBagError;

    public voidbuildClassifier(Instances data) throwsException{}

    public double[] distributionForInstance(Instance instance) throwsException{}

    public static void main(String[] argv) {

        runClassifier(newBagging(), argv);

    }

繼承——>RandomizableIteratedSingleClassifierEnhancer

(protected int m_Seed = 1;)

繼承——>IteratedSingleClassifierEnhancer

(protected Classifier[] m_Classifiers;

    protected intm_NumIterations = 10;

    public voidbuildClassifier(Instancesdata) throws Exception{

        if (m_Classifier == null) {

             throw newException("A base classifier has not been specified!");

        }

        m_Classifiers = Classifier.makeCopies(m_Classifier,m_NumIterations);

     }

繼承——>SingleClassifierEnhancer

(protected Classifierm_Classifier = new ZeroR();

    public void setClassifier(ClassifiernewClassifier){}

    public Classifier getClassifier(){}

    protected String getClassifierSpec(){}

繼承——>Classifier

(protectedboolean m_Debug = false;

    publicabstract voidbuildClassifier(Instances data) throwsException;

    public double classifyInstance(Instance instance) throwsException{}

    public double[] distributionForInstance(Instance instance)throws Exception{}

    public static Classifier forName(StringclassifierName, String[] options) throws Exception{}

    public static Classifier makeCopy(Classifier model) throws Exception{}

    public static Classifier[] makeCopies(Classifier model, int num) throwsException{}

    protected static void runClassifier(Classifier classifier, String[]options){}

3、父類引用指向子類對象:多态、動态連結,向上轉型(插曲)

ZeroR——> Classifier

Protected  Classifier  m_Classifier  =  newZeroR();

對于多态,可以總結以下幾點:

Ø  使用父類類型的引用指向子類的對象;

Ø  該引用隻能調用父類中定義的方法和變量;

Ø  如果子類中重寫了父類中的一個方法,那麼在調用這個方法的時候,将會調用子類中的這個方法;(動态連接配接、動态調用)

Ø  變量不能被重寫(覆寫),”重寫“的概念隻針對方法,如果在子類中”重寫“了父類中的變量,那麼在編譯時會報錯。

一個父類類型的引用指向一個子類的對象既可以使用子類強大的功能,又可以抽取父類的共性,父類類型的引用可以調用父類中定義的所有屬性和方法,而對于子類中定義而父類中沒有的方法,父類引用是無法調用的;

那什麼是動态連結呢?當父類中的一個方法隻有在父類中定義而在子類中沒有重寫的情況下,才可以被父類類型的引用調用;對于父類中定義的方法,如果子類中重寫了該方法,那麼父類類型的引用将會調用子類中的這個方法,這就是動态連接配接。

注:當超類對象引用變量引用子類對象時,被引用對象的類型而不是引用變量的類型決定了調用誰的成員方法,但是這個被調用的方法必須是在超類中定義過的,也就是說被子類覆寫的方法。

4、abstract的用法(插曲)

²  abstract修飾類,會使這個類成為一個抽象類,這個類将不能生成對象執行個體,可以做為對象變量聲明的類型,也就是編譯時類型,抽象類就像當于一類的半成品,需要子類繼承并覆寫其中的抽象方法。

²  abstract修飾方法,會使這個方法變成抽象方法,聲明(定義)而沒有實作,實作部分以";"代替。需要子類繼承實作(覆寫)。

²  abstract修飾符在修飾類時必須放在類名前。

²  abstract修飾方法就是要求其子類覆寫(實作)這個方法。調用時可以以多态方式調用子類覆寫(實作)後的方法,也就是說抽象方法必須在其子類中實作,除非子類本身也是抽象類。

²  父類是抽象類,有抽象方法,子類繼承父類,并把父類中的所有抽象方法都實作(覆寫),抽象類中有構造方法,是子類在構造子類對象時需要調用的父類(抽象類)的構造方法。

5、Bagging運作過程剖析:

1)運作主函數runClassifier

public static void main(String[] argv) {
    runClassifier(new Bagging(), argv);
  }
           

2)構造函數選擇分類器

//Bagging構造函數預設選擇分類器REPTree
public Bagging() {
    m_Classifier = new weka.classifiers.trees.REPTree();
}
           

3)buildClassifier

    ① 處理資料  

public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();
           

    ② 調用父類buildClassifier()方法

super.buildClassifier(data);  //調用父類的方法:【IteratedSingleClassifierEnhancer】中的【buildClassifier()】
           

         [此方法可以得到多個分類器m_Classifiers,分類器類型與m_Classifier一緻]

//父類IteratedSingleClassifierEnhancer.java中的buildClassifier過程

public void buildClassifier(Instances data) throws Exception {

    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    m_Classifiers = Classifier.makeCopies(m_Classifier, m_NumIterations);
  }
           

    ③在m_CalcOutOfBag為真且m_BagSizePercent = 100時,準備計算OOB 及抽樣資料

if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
      throw new IllegalArgumentException("Bag size needs to be 100% if "
          + "out-of-bag error is to be calculated!");
    }

    int bagSize = (int) (data.numInstances() * (m_BagSizePercent / 100.0));
    Random random = new Random(m_Seed);

    boolean[][] inBag = null;    //inBag:一行代表一種分類器的抽樣樣本情況
    if (m_CalcOutOfBag)
      inBag = new boolean[m_Classifiers.length][];   //[m_Classifiers.length]為Classifier[]數組長度,即分類器的個數

    for (int j = 0; j < m_Classifiers.length; j++) {
      Instances bagData = null;

      // create the in-bag dataset
      if (m_CalcOutOfBag) {       //計算OOB時,inBag為二維數組存取樣本采樣情況
        inBag[j] = new boolean[data.numInstances()];
        // bagData = resampleWithWeights(data, random, inBag[j]);
        bagData = data.resampleWithWeights(random, inBag[j]);   //[resampleWithWeights]有放回采樣資料
      } else {                    //不計算OOB時也沒必要用inBag了
        bagData = data.resampleWithWeights(random);             //[resampleWithWeights]有放回采樣資料
        if (bagSize < data.numInstances()) {
          bagData.randomize(random);
          Instances newBagData = new Instances(bagData, 0, bagSize);
          bagData = newBagData;
        }
      }

      if (m_Classifier instanceof Randomizable) {   //Randomizable接口:設定seed
        ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());
      }
           

    ④ 建構分類樹,選擇m_Classifier所為分類器的分類方法,預設為REPTree方法

// build the classifier
      m_Classifiers[j].buildClassifier(bagData);   //建構分類樹,調用m_Classifier所為分類器的buildClassifier()方法
    }
           

    ⑤ 計算OOB誤差情況

// calc OOB error?
    if (getCalcOutOfBag()) {
      double outOfBagCount = 0.0;
      double errorSum = 0.0;
      boolean numeric = data.classAttribute().isNumeric();

      for (int i = 0; i < data.numInstances(); i++) {
        double vote;
        double[] votes;
        if (numeric)
          votes = new double[1];  //數值型求均值,一個數組單元
        else
          votes = new double[data.numClasses()];  //枚舉型需要投票

        // determine predictions for instance
        int voteCount = 0;
        for (int j = 0; j < m_Classifiers.length; j++) {
          if (inBag[j][i])  //尋找未被抽到的樣本執行個體,用來計算OOB
            continue;

          voteCount++;
          // double pred = m_Classifiers[j].classifyInstance(data.instance(i));
          if (numeric) {   //數值型
            // votes[0] += pred;
            votes[0] += m_Classifiers[j].classifyInstance(data.instance(i));  //數值型直接把預測結果累加
          } else {
            // votes[(int) pred]++;
            double[] newProbs = m_Classifiers[j].distributionForInstance(data.instance(i));
            // average the probability estimates
            for (int k = 0; k < newProbs.length; k++) {
              votes[k] += newProbs[k];    //枚舉型要累加枚舉機率
            }
          }
        }

        // "vote"
        if (numeric) {
          vote = votes[0];
          if (voteCount > 0) {
            vote /= voteCount; // average  算數均值
          }
        } else {
          if (Utils.eq(Utils.sum(votes), 0)) {
          } else {
            Utils.normalize(votes);   //歸一化
          }
          vote = Utils.maxIndex(votes); // predicted class  選出最大的index
        }

        // error for instance
        outOfBagCount += data.instance(i).weight();  //累權重重
        if (numeric) {
          errorSum += StrictMath.abs(vote - data.instance(i).classValue())*data.instance(i).weight(); //累加錯誤偏差
        } else {
          if (vote != data.instance(i).classValue())
            errorSum += data.instance(i).weight();   //枚舉型對出錯進行計數
        }
      }

      m_OutOfBagError = errorSum / outOfBagCount;
    } else {
      m_OutOfBagError = 0;   //不計算OOB了
    }
  }
           

4)主要訓練過程在于bagging的基分類器,預設為REPTree

6、Bagging建立分類樹過程:

建分類器完整代碼:

/**
   * Bagging method.
   * 
   * @param data the training data to be used for generating the bagged
   *          classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    super.buildClassifier(data);  //調用父類的方法:【IteratedSingleClassifierEnhancer】中的【buildClassifier()】

    if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
      throw new IllegalArgumentException("Bag size needs to be 100% if "
          + "out-of-bag error is to be calculated!");
    }

    int bagSize = (int) (data.numInstances() * (m_BagSizePercent / 100.0));
    Random random = new Random(m_Seed);

    boolean[][] inBag = null;    //inBag:一行代表一種分類器的抽樣樣本情況
    if (m_CalcOutOfBag)
      inBag = new boolean[m_Classifiers.length][];   //[m_Classifiers.length]為Classifier[]數組長度,即分類器的個數

    for (int j = 0; j < m_Classifiers.length; j++) {
      Instances bagData = null;

      // create the in-bag dataset
      if (m_CalcOutOfBag) {       //計算OOB時,inBag為二維數組存取樣本采樣情況
        inBag[j] = new boolean[data.numInstances()];
        // bagData = resampleWithWeights(data, random, inBag[j]);
        bagData = data.resampleWithWeights(random, inBag[j]);   //[resampleWithWeights]有放回采樣資料
      } else {                    //不計算OOB時也沒必要用inBag了
        bagData = data.resampleWithWeights(random);             //[resampleWithWeights]有放回采樣資料
        if (bagSize < data.numInstances()) {
          bagData.randomize(random);
          Instances newBagData = new Instances(bagData, 0, bagSize);
          bagData = newBagData;
        }
      }

      if (m_Classifier instanceof Randomizable) {   //Randomizable接口:設定seed
        ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());
      }

      // build the classifier
      m_Classifiers[j].buildClassifier(bagData);   //建構分類樹,調用m_Classifier所為分類器的buildClassifier()方法
    }

    // calc OOB error?
    if (getCalcOutOfBag()) {
      double outOfBagCount = 0.0;
      double errorSum = 0.0;
      boolean numeric = data.classAttribute().isNumeric();

      for (int i = 0; i < data.numInstances(); i++) {
        double vote;
        double[] votes;
        if (numeric)
          votes = new double[1];  //數值型求均值,一個數組單元
        else
          votes = new double[data.numClasses()];  //枚舉型需要投票

        // determine predictions for instance
        int voteCount = 0;
        for (int j = 0; j < m_Classifiers.length; j++) {
          if (inBag[j][i])  //尋找未被抽到的樣本執行個體,用來計算OOB
            continue;

          voteCount++;
          // double pred = m_Classifiers[j].classifyInstance(data.instance(i));
          if (numeric) {   //數值型
            // votes[0] += pred;
            votes[0] += m_Classifiers[j].classifyInstance(data.instance(i));  //數值型直接把預測結果累加
          } else {
            // votes[(int) pred]++;
            double[] newProbs = m_Classifiers[j].distributionForInstance(data.instance(i));
            // average the probability estimates
            for (int k = 0; k < newProbs.length; k++) {
              votes[k] += newProbs[k];    //枚舉型要累加枚舉機率
            }
          }
        }

        // "vote"
        if (numeric) {
          vote = votes[0];
          if (voteCount > 0) {
            vote /= voteCount; // average  算數均值
          }
        } else {
          if (Utils.eq(Utils.sum(votes), 0)) {
          } else {
            Utils.normalize(votes);   //歸一化
          }
          vote = Utils.maxIndex(votes); // predicted class  選出最大的index
        }

        // error for instance
        outOfBagCount += data.instance(i).weight();  //累權重重
        if (numeric) {
          errorSum += StrictMath.abs(vote - data.instance(i).classValue())*data.instance(i).weight(); //累加錯誤偏差
        } else {
          if (vote != data.instance(i).classValue())
            errorSum += data.instance(i).weight();   //枚舉型對出錯進行計數
        }
      }

      m_OutOfBagError = errorSum / outOfBagCount;
    } else {
      m_OutOfBagError = 0;   //不計算OOB了
    }
  }
           

繼續閱讀