天天看點

Weka開發——REPTree源代碼分析

2009-05-28 12:46:41|  分類: 機器學習|字号 訂閱

    如果你分析完了ID3,還想進一步學習,最好還是先學習REPTree,它沒有牽扯到那麼多類,兩個類完成了全部的工作,看起來比較清楚,J48雖然有很強的可擴充性,但是初看起來還是有些費力,REPTree也是我賣算法時(為了買一台運算能力強一點的計算機,我也不得不賺錢),順便分析的,但因為我以前介紹過J48了,重複的東西不想再次介紹了,如果有什麼不明白的,就把我兩篇寫的結合起來看吧。

    我們再次從buildClassifier開始。

Random random = new Random(m_Seed);

m_zeroR = null;

if (data.numAttributes() == 1) {

    m_zeroR = new ZeroR();

    m_zeroR.buildClassifier(data);

    return;

}

         如果就隻有一個屬性,也就是類别屬性,就用ZeroR分類器學習,ZeroR分類器傳回訓練集中出現最多的類别值,已經講過了Weka開發[15]。

// Randomize and stratify

data.randomize(random);

if (data.classAttribute().isNominal()) {

    data.stratify(m_NumFolds);

}

         randomize就是把data中的資料重排一下,如果類别屬性是離散值,那麼用stratify函數,stratify意思是分層,現在把這個函數列出來:

public void stratify(int numFolds) {

if (classAttribute().isNominal()) {

       // sort by class

       int index = 1;

       while (index < numInstances()) {

           Instance instance1 = instance(index - 1);

           for (int j = index; j < numInstances(); j++) {

              Instance instance2 = instance(j);

              if ((instance1.classValue() == instance2.classValue())

                     || (instance1.classIsMissing() && instance2

                                .classIsMissing())) {

                  swap(index, j);

                  index++;

              }

           }

           index++;

       }

       stratStep(numFolds);

    }

}

         上面這兩重循環,就是根據類别值進行冒泡。下面有調用了stratStep函數:

protected void stratStep(int numFolds) {

    FastVector newVec = new FastVector(m_Instances.capacity());

    int start = 0, j;

    // create stratified batch

    while (newVec.size() < numInstances()) {

       j = start;

       while (j < numInstances()) {

           newVec.addElement(instance(j));

           j = j + numFolds;

       }

       start++;

    }

    m_Instances = newVec;

}

         這裡我舉一個例子說明:j=0時,numFolds為10時,newVec加入的instance下标就為0,10,20…。這樣的好處就是我們把各種類别的樣本類似平均分布了。

// Split data into training and pruning set

Instances train = null;

Instances prune = null;

if (!m_NoPruning) {

    train = data.trainCV(m_NumFolds, 0, random);

    prune = data.testCV(m_NumFolds, 0);

} else {

    train = data;

}

關于trainCV這個就不講了,就是crossValidation的第0個訓練集作為這次的訓練集(train)。而作為剪枝的資料集prune為第0個測試集。

// Create array of sorted indices and weights

int[][] sortedIndices = new int[train.numAttributes()][0];

double[][] weights = new double[train.numAttributes()][0];

double[] vals = new double[train.numInstances()];

for (int j = 0; j < train.numAttributes(); j++) {

    if (j != train.classIndex()) {

       weights[j] = new double[train.numInstances()];

       if (train.attribute(j).isNominal()) {

           // Handling nominal attributes. Putting indices of

           // instances with missing values at the end.

           sortedIndices[j] = new int[train.numInstances()];

           int count = 0;

           for (int i = 0; i < train.numInstances(); i++) {

              Instance inst = train.instance(i);

              if (!inst.isMissing(j)) {

                  sortedIndices[j][count] = i;

                  weights[j][count] = inst.weight();

                  count++;

              }

           }

           for (int i = 0; i < train.numInstances(); i++) {

              Instance inst = train.instance(i);

              if (inst.isMissing(j)) {

                  sortedIndices[j][count] = i;

                  weights[j][count] = inst.weight();

                  count++;

              }

           }

       } else {

           // Sorted indices are computed for numeric attributes

           for (int i = 0; i < train.numInstances(); i++) {

              Instance inst = train.instance(i);

              vals[i] = inst.value(j);

           }

           sortedIndices[j] = Utils.sort(vals);

           for (int i = 0; i < train.numInstances(); i++) {

              weights[j][i] = train.instance(sortedIndices[j][i])

                     .weight();

           }

       }

    }

}

         sortedIndices表示第j屬性的第count個樣本下标是多少,weights表示第j個屬性第count個樣本的權重,如果j屬性是離散值,通過兩個for循環,在sortedIndices和weights中在j屬性上是缺失值的樣本就排在了後面。如果是連續值,那麼就把全部樣本j屬性值得到,再排序,最後記錄權重。

// Compute initial class counts

double[] classProbs = new double[train.numClasses()];

double totalWeight = 0, totalSumSquared = 0;

for (int i = 0; i < train.numInstances(); i++) {

    Instance inst = train.instance(i);

    if (data.classAttribute().isNominal()) {

       classProbs[(int) inst.classValue()] += inst.weight();

       totalWeight += inst.weight();

    } else {

       classProbs[0] += inst.classValue() * inst.weight();

       totalSumSquared += inst.classValue() * inst.classValue()

              * inst.weight();

       totalWeight += inst.weight();

    }

}

m_Tree = new Tree();

double trainVariance = 0;

if (data.classAttribute().isNumeric()) {

    trainVariance = m_Tree.singleVariance(classProbs[0],

           totalSumSquared, totalWeight) / totalWeight;

    classProbs[0] /= totalWeight;

}

         計算初始化類别機率,如果類别是離散值,classProbs中記錄的是屬性類别inst.classValue()的樣本權重之和,totalWeight是全部樣本權重和。如果類别是連續值,classProbs[0]中是權重乘以類别值,它還有一個totalSumSquared是類别值平方乘以權重之和。

         m_Tree是一個Tree對象,如果是連續值類别,用m_Tree的成員函數來計算trainVariance這個帶權重的方差,最後classProbs[0]相當于期望。

// Build tree

m_Tree.buildTree(sortedIndices, weights, train, totalWeight,

       classProbs, new Instances(train, 0), m_MinNum,

       m_MinVarianceProp * trainVariance, 0, m_MaxDepth);

    有長度限制,我拆成了兩部分。   

    好了,終于可以建樹了,除了VC,我還真沒怎麼見過這麼多參數。現在把它拆開分析:

// Store structure of dataset, set minimum number of instances

// and make space for potential info from pruning data

m_Info = header;

m_HoldOutDist = new double[data.numClasses()];

// Make leaf if there are no training instances

int helpIndex = 0;

if (data.classIndex() == 0) {

    helpIndex = 1;

}

if (sortedIndices[helpIndex].length == 0) {

    if (data.classAttribute().isNumeric()) {

       m_Distribution = new double[2];

    } else {

       m_Distribution = new double[data.numClasses()];

    }

    m_ClassProbs = null;

    return;

}

         m_Info儲存的是資料集的表頭結構,m_HoldOutDist後面會講到,是用于剪枝的。這面這個有點意思,helpIndex在類别index不是0的情況下是1,否則是0,因為sortedIndices中沒有類别列。初始化m_Distribution,如果是連續值,數組長度是2,第一個儲存方差,後面是樣本總權重。離散值不會說,當然是類别值個數。

double priorVar = 0;

if (data.classAttribute().isNumeric()) {

    // Compute prior variance

    double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;

    for (int i = 0; i < sortedIndices[helpIndex].length; i++) {

       Instance inst = data.instance(sortedIndices[helpIndex][i]);

       totalSum += inst.classValue() * weights[helpIndex][i];

       totalSumSquared += inst.classValue() * inst.classValue()

              * weights[helpIndex][i];

       totalSumOfWeights += weights[helpIndex][i];

    }

    priorVar = singleVariance(totalSum, totalSumSquared,

           totalSumOfWeights);

}

         這個就非常簡單了,如果類别是連續值。再說一下,這裡helpIndex無所謂,隻要不是類别index就好。totalSum是類别值與樣本權重的乘積和,totalSumSquared是類别值平方乘樣本權重和,totalSumOfWeights是權重和。這裡還是說一下,singleVariance就是變換後的期望計算公式。

// Check if node doesn't contain enough instances, is pure

// or the maximum tree depth is reached

m_ClassProbs = new double[classProbs.length];

System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);

if ((totalWeight < (2 * minNum))

       ||

       // Nominal case

       (data.classAttribute().isNominal() && Utils.eq(

              m_ClassProbs[Utils.maxIndex(m_ClassProbs)], Utils

                     .sum(m_ClassProbs)))

       ||

       // Numeric case

       (data.classAttribute().isNumeric() && ((priorVar / totalWeight)

 < minVariance))

       ||

       // Check tree depth

       ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

    // Make leaf

    m_Attribute = -1;

    if (data.classAttribute().isNominal()) {

       // Nominal case

       m_Distribution = new double[m_ClassProbs.length];

       for (int i = 0; i < m_ClassProbs.length; i++) {

           m_Distribution[i] = m_ClassProbs[i];

       }

       Utils.normalize(m_ClassProbs);

    } else {

       // Numeric case

       m_Distribution = new double[2];

       m_Distribution[0] = priorVar;

       m_Distribution[1] = totalWeight;

    }

    return;

}

         先看一下不會再分裂的情況,第一種,總樣本權重還不到最小分裂樣本數的2倍(因為至少要分出來兩個子結點嘛),第二種,類别是離散值的情況下,如果樣本都屬于一個類别(以前講過為什麼)。第三種,類别是連續值的情況下,如果方差小于一個最小方差,最小方差是由一個定義的常數與總方差的積。最後一種如果超過了定義的樹的深度。

         如果是離散值,就将m_ClassProbs數組中的内容複制到m_Distribution中,再進行規範化,如果是連續值,把方差和總權重儲存。

// Compute class distributions and value of splitting

// criterion for each attribute

double[] vals = new double[data.numAttributes()];

double[][][] dists = new double[data.numAttributes()][0][0];

double[][] props = new double[data.numAttributes()][0];

double[][] totalSubsetWeights = new double[data.numAttributes()][0];

double[] splits = new double[data.numAttributes()];

if (data.classAttribute().isNominal()) {

    // Nominal case

    for (int i = 0; i < data.numAttributes(); i++) {

       if (i != data.classIndex()) {

           splits[i] = distribution(props, dists, i,

                  sortedIndices[i], weights[i],

                  totalSubsetWeights, data);

           vals[i] = gain(dists[i], priorVal(dists[i]));

       }

    }

} else {

    // Numeric case

    for (int i = 0; i < data.numAttributes(); i++) {

       if (i != data.classIndex()) {

           splits[i] = numericDistribution(props, dists, i,

                  sortedIndices[i], weights[i],

                  totalSubsetWeights, data, vals);

       }

    }

}

         這裡出現了一下ditribution函數,也是非常長,但是又很重要,是以我還是先介紹它:

double splitPoint = Double.NaN;

Attribute attribute = data.attribute(att);

double[][] dist = null;

int i;

if (attribute.isNominal()) {

    // For nominal attributes

    dist = new double[attribute.numValues()][data.numClasses()];

    for (i = 0; i < sortedIndices.length; i++) {

       Instance inst = data.instance(sortedIndices[i]);

       if (inst.isMissing(att)) {

           break;

       }

       dist[(int) inst.value(att)][(int) inst.classValue()] +=

           weights[i];

    }

}

         先講一下離散值的情況,實作與j48包下面的Distribution非常相似,dist第一維是屬性值,第二維是類别值,元素值是樣本權重累加值。

else {

    // For numeric attributes

    double[][] currDist = new double[2][data.numClasses()];

    dist = new double[2][data.numClasses()];

    // Move all instances into second subset

    for (int j = 0; j < sortedIndices.length; j++) {

        Instance inst = data.instance(sortedIndices[j]);

       if (inst.isMissing(att)) {

           break;

       }

       currDist[1][(int) inst.classValue()] += weights[j];

    }

    double priorVal = priorVal(currDist);

    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points

    double currSplit = data.instance(sortedIndices[0]).value(att);

    double currVal, bestVal = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {

       Instance inst = data.instance(sortedIndices[i]);

       if (inst.isMissing(att)) {

           break;

       }

       if (inst.value(att) > currSplit) {

           currVal = gain(currDist, priorVal);

           if (currVal > bestVal) {

              bestVal = currVal;

              splitPoint = (inst.value(att) + currSplit) / 2.0;

              for (int j = 0; j < currDist.length; j++) {

                  System.arraycopy(currDist[j], 0, dist[j], 0,

                         dist[j].length);

              }

           }

       }

       currSplit = inst.value(att);

       currDist[0][(int) inst.classValue()] += weights[i];

       currDist[1][(int) inst.classValue()] -= weights[i];

    }

}

         不想講了,和J48也是一樣,先把樣本存在後一子結點中currDist[1],然後依次試屬性值,找到一個最好看分裂點。

// Compute weights

props[att] = new double[dist.length];

for (int k = 0; k < props[att].length; k++) {

    props[att][k] = Utils.sum(dist[k]);

}

if (!(Utils.sum(props[att]) > 0)) {

    for (int k = 0; k < props[att].length; k++) {

       props[att][k] = 1.0 / (double) props[att].length;

    }

} else {

    Utils.normalize(props[att]);

}

         props中儲存的就是第att個屬性的第k個屬性值的樣本權重之和。如果這個值不太于0,就給它指派為1除以這個屬性的全部可能取值。否則規範化。

// Distribute counts

while (i < sortedIndices.length) {

    Instance inst = data.instance(sortedIndices[i]);

    for (int j = 0; j < dist.length; j++) {

       dist[j][(int) inst.classValue()] += props[att][j]

              * weights[i];

    }

    i++;

}

// Compute subset weights

subsetWeights[att] = new double[dist.length];

for (int j = 0; j < dist.length; j++) {

    subsetWeights[att][j] += Utils.sum(dist[j]);

}

// Return distribution and split point

dists[att] = dist;

return splitPoint;

         i這裡初始是有确定屬性值與缺失值的分界下标值,開始一時頭暈還沒看出來,調試才看出來。如果有缺失值,就用每一個屬性值都加上相應的權重來代替。在att屬性上分裂,那種子結點的權重和為dist在j這種屬性取值上的和。最後把dist指派給dists[att],傳回分裂點。

         現在再跳回到buildTree函數,接着講gain函數就是計算資訊增益,不講了。numericDistribution還是這麼長,而且也差不多,也就算了吧。

// Find best attribute

m_Attribute = Utils.maxIndex(vals);

int numAttVals = dists[m_Attribute].length;

// Check if there are at least two subsets with

// required minimum number of instances

int count = 0;

for (int i = 0; i < numAttVals; i++) {

    if (totalSubsetWeights[m_Attribute][i] >= minNum) {

       count++;

    }

    if (count > 1) {

       break;

    }

}

         vals中資訊增益值,m_Attribute就是有最大資訊增益值的屬性下标,再下來看是否這個屬性可以分出兩個大于minNum樣本數的子結點。

// Any useful split found?

if ((vals[m_Attribute] > 0) && (count > 1)) {

    // Build subtrees

    m_SplitPoint = splits[m_Attribute];

    m_Prop = props[m_Attribute];

    int[][][] subsetIndices = new int[numAttVals][data

           .numAttributes()][0];

    double[][][] subsetWeights = new double[numAttVals][data

           .numAttributes()][0];

    splitData(subsetIndices, subsetWeights, m_Attribute,

           m_SplitPoint, sortedIndices, weights, data);

    m_Successors = new Tree[numAttVals];

    for (int i = 0; i < numAttVals; i++) {

       m_Successors[i] = new Tree();

       m_Successors[i].buildTree(subsetIndices[i],

              subsetWeights[i], data,

              totalSubsetWeights[m_Attribute][i],

              dists[m_Attribute][i], header, minNum, minVariance,

              depth + 1, maxDepth);

    }

} else {

    // Make leaf

    m_Attribute = -1;

}

         如果找到了可以分裂的屬性,那我們就可以建立了樹了,看起來亂七八糟很複雜的樣子,其實如果你把上面講的搞清楚了,這裡和ID3,J48沒有什麼差別。如果不能分裂,就把m_Attribute置1,标記一下。

// Normalize class counts

if (data.classAttribute().isNominal()) {

    m_Distribution = new double[m_ClassProbs.length];

    for (int i = 0; i < m_ClassProbs.length; i++) {

       m_Distribution[i] = m_ClassProbs[i];

    }

    Utils.normalize(m_ClassProbs);

} else {

    m_Distribution = new double[2];

    m_Distribution[0] = priorVar;

    m_Distribution[1] = totalWeight;

}

         這個其實沒什麼好講的,隻是指派到m_Distribution,建樹就已經講完了。但是在buildClassifier我們還剩下三行,是關于剪枝的,當時在介紹J48的時候,就沒有講,因為我不需要用那部分,當時也沒怎麼看。

// Insert pruning data and perform reduced error pruning

if (!m_NoPruning) {

    m_Tree.insertHoldOutSet(prune);

    m_Tree.reducedErrorPrune();

    m_Tree.backfitHoldOutSet(prune);

}

         如果非不剪枝,那麼就是剪枝了,先看第一個被調用的函數:

protected void insertHoldOutSet(Instances data) throws Exception {

    for (int i = 0; i < data.numInstances(); i++) {

       insertHoldOutInstance(data.instance(i), data.instance(i)

              .weight(), this);

    }

}

         prune資料集中的每一個樣本作為參數調用insertHoldOutInstance,它也有點長,把它一部分一部分列出來:

// Insert instance into hold-out class distribution

if (inst.classAttribute().isNominal()) {

    // Nominal case

    m_HoldOutDist[(int) inst.classValue()] += weight;

    int predictedClass = 0;

    if (m_ClassProbs == null) {

       predictedClass = Utils.maxIndex(parent.m_ClassProbs);

    } else {

       predictedClass = Utils.maxIndex(m_ClassProbs);

    }

    if (predictedClass != (int) inst.classValue()) {

       m_HoldOutError += weight;

    }

} else {

    // Numeric case

    m_HoldOutDist[0] += weight;

    double diff = 0;

    if (m_ClassProbs == null) {

       diff = parent.m_ClassProbs[0] - inst.classValue();

    } else {

       diff = m_ClassProbs[0] - inst.classValue();

    }

    m_HoldOutError += diff * diff * weight;

}

         看一下離散的情況,如果是離散類别,看它預測出的類别是否與真實類别相同,如果不同,就把樣本權重累加到m_HoldOutError上,其中==null的情況應該是這個葉子結點上曾經分的時候就沒樣本。在連續類别時,是把預測值與真實值的差的平方乘權重加到m_holdOutError上,

// The process is recursive

if (m_Attribute != -1) {

    // If node is not a leaf

    if (inst.isMissing(m_Attribute)) {

       // Distribute instance

       for (int i = 0; i < m_Successors.length; i++) {

           if (m_Prop[i] > 0) {

              m_Successors[i].insertHoldOutInstance(inst, weight

                     * m_Prop[i], this);

           }

       }

    } else {

       if (m_Info.attribute(m_Attribute).isNominal()) {

           // Treat nominal attributes

           m_Successors[(int) inst.value(m_Attribute)]

                  .insertHoldOutInstance(inst, weight, this);

       } else {

           // Treat numeric attributes

           if (inst.value(m_Attribute) < m_SplitPoint) {

              m_Successors[0].insertHoldOutInstance(inst, weight,

                      this);

           } else {

              m_Successors[1].insertHoldOutInstance(inst, weight,

                     this);

           }

       }

    }

}

         m_Attribute等于-1時就是葉子結點,前面已經講過了,如果是缺失值的情況,又是把所有可能算一遍(前兩天看論文,有一篇論文提到對缺失值的運作,在C4.5中占到了80%的時間)。如果不是缺失值就遞歸。這個函數整體的含義就是計算父結點和子結點,為最後看分還是不分好做準備。

         好了,看第二個函數:

protected double reducedErrorPrune() throws Exception {

    // Is node leaf ?

    if (m_Attribute == -1) {

       return m_HoldOutError;

    }

    // Prune all sub trees

    double errorTree = 0;

    for (int i = 0; i < m_Successors.length; i++) {

       errorTree += m_Successors[i].reducedErrorPrune();

    }

    // Replace sub tree with leaf if error doesn't get worse

    if (errorTree >= m_HoldOutError) {

       m_Attribute = -1;

       m_Successors = null;

       return m_HoldOutError;

    } else {

       return errorTree;

    }

}

         如果開始就是葉子結點,太不可思議了,直接傳回。接下來,這是一個遞歸,遞歸就在做一件事情,如果幾個子結點的錯誤加起來比父結點還高,意思也就是說分裂比不分裂還要差,那麼我們就把子結點剪去,也就是剪枝,在這裡是剪葉子?剪枝的時候,設定m_Attribute,然後把子結點置空,傳回父結點的錯誤值。

         最後一個函數:

protected void backfitHoldOutSet(Instances data) throws Exception {

    for (int i = 0; i < data.numInstances(); i++) {

       backfitHoldOutInstance(data.instance(i), data.instance(i)

              .weight(), this);

    }

}

         backfitHoldOutInstance不難,但是還有有點長,分開貼出來:

// Insert instance into hold-out class distribution

if (inst.classAttribute().isNominal()) {

    // Nominal case

    if (m_ClassProbs == null) {

       m_ClassProbs = new double[inst.numClasses()];

    }

    System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, inst

           .numClasses());

    m_ClassProbs[(int) inst.classValue()] += weight;

    Utils.normalize(m_ClassProbs);

} else {

    // Numeric case

    if (m_ClassProbs == null) {

       m_ClassProbs = new double[1];

    }

    m_ClassProbs[0] *= m_Distribution[1];

    m_ClassProbs[0] += weight * inst.classValue();

    m_ClassProbs[0] /= (m_Distribution[1] + weight);

}

         這個函數主要是把以前用訓練集測出來的值,現在把剪枝集的樣本資訊也加進去。這些以前也都講過。

// The process is recursive

if (m_Attribute != -1) {

    // If node is not a leaf

    if (inst.isMissing(m_Attribute)) {

       // Distribute instance

       for (int i = 0; i < m_Successors.length; i++) {

           if (m_Prop[i] > 0) {

              m_Successors[i].backfitHoldOutInstance(inst, weight

                     * m_Prop[i], this);

           }

       }

    } else {

       if (m_Info.attribute(m_Attribute).isNominal()) {

           // Treat nominal attributes

           m_Successors[(int) inst.value(m_Attribute)]

                  .backfitHoldOutInstance(inst, weight, this);

       } else {

           // Treat numeric attributes

           if (inst.value(m_Attribute) < m_SplitPoint) {

              m_Successors[0].backfitHoldOutInstance(inst,

                     weight, this);

           } else {

              m_Successors[1].backfitHoldOutInstance(inst,

                     weight, this);

           }

       }

    }

}

         不想講了,自己看吧,distributionForInstance也不講了,如果是一直看我的東西過來的,到現在還不明白,我也沒話說了。

繼續閱讀