天天看点

weka之NB算法

@Override
    public void buildClassifier(Instances data) throws Exception 
    {
        //检测分类器能否处理数据
        getCapabilities().testWithFail(data);
        //删除具有类别缺失值的实例
        data=new Instances(data);
        data.deleteWithMissingClass();
        //保存类别的数量
        m_NumClasses=data.numClasses();
        //复制训练集
        m_Instances=new Instances(data);
        //如果指定,就对数据进行离散化
        if(m_UseDiscretization)
        {
            m_Disc=new weka.filters.supervised.attribute.Discretize();
            m_Disc.setInputFormat(data);
            m_Instances=weka.filters.Filter.useFilter(m_Instances, m_Disc);
        }
        else
        {
            m_Disc=null;
        }

        //为概率分布预留空间
        //类别条件概率分布P(X|Y)
        m_Distributions=new Estimator[m_Instances.numAttributes()-][m_Instances.numClasses()];
        //类别分布P(Y)
        m_ClassDistribution=new DiscreteEstimator(m_Instances.numClasses(), true);
        int attIndex=;
        Enumeration enumeration=m_Instances.enumerateAttributes();
        //循环处理每一个属性
        while(enumeration.hasMoreElements())
        {
            Attribute attribute=(Attribute) enumeration.nextElement();

            //如果属性是数值型,根据相邻值之间的差异,测定估计器数值精度
            double numPrecision=DEFAULT_NUM_PRECISION;
            if(attribute.type()==Attribute.NUMERIC)
            {
                //根据当前属性的值对数据集排序
                m_Instances.sort(attribute);
                //排序之后,当前属性缺失值的实例就排到最前
                //这样,判断第一个样本是否有缺失值,就知道整体样本是否有缺失值
                //如果有,就没有必要执行if后面的代码块
                if((m_Instances.numInstances()>) && !m_Instances.instance().isMissing(attribute))
                {
                    //lastVal为后一个实例的当前属性值
                    double lastVal=m_Instances.instance().value(attribute);
                    //currentVal,为每个实例的当前属性值,deltaSum为差值
                    double currentVal,deltaSum=;
                    //distinct为当前属性取不同值的数量
                    int distinct=;
                    for(int i=;i<m_Instances.numInstances();i++)
                    {
                        Instance currentInst=m_Instances.instance(i);
                        if(currentInst.isMissing(attribute))
                        {
                            break;
                        }
                        currentVal=currentInst.value(attribute);
                        //如果当前值与最后值不相等,则相减并将差值累加到deltaSum
                        if(currentVal!=lastVal)
                        {
                            deltaSum+=currentVal-lastVal;
                            lastVal=currentVal;
                            distinct++;
                        }
                    }
                    //最终的numPrecision就是deltaSum/distinct
                    if(distinct>)
                    {
                        numPrecision=deltaSum/distinct;
                    }
                }
            }

            //循环处理每一个类别标签
            for(int j=;j<m_Instances.numClasses();j++)
            {
                //判断当前属性的类型
                switch(attribute.type())
                {
                //如果为连续的数值型属性,根据是否使用核估计器的选项,选择构建Kernelstimator对象还是NormalEstimator对象
                //两者的构造函数都是使用numPrecision作为参数
                case Attribute.NUMERIC:
                    if(m_UseKernelEstimator)
                    {
                        m_Distributions[attIndex][j]=new KernelEstimator(numPrecision);
                    }
                    else
                    {
                        m_Distributions[attIndex][j]=new NormalEstimator(numPrecision);
                    }
                    break;
                case Attribute.NOMINAL:
                    m_Distributions[attIndex][j]=new DiscreteEstimator(attribute.numValues(), true);
                    break;
                default:
                    throw new Exception("Attribute type unkown to my NB");
                }
            }
            attIndex++;
        }

        //统计每一个实例
        Enumeration enumInsts=m_Instances.enumerateInstances();
        while (enumInsts.hasMoreElements()) 
        {
            Instance instance=(Instance) enumInsts.nextElement();
            //调用updateClassifier方法,用实例更新分离器
            updateClassifier(instance);
        }

        //节省空间
        m_Instances=new Instances(m_Instances,);
    }

    public void updateClassifier(Instance instance) 
    {
        if(!instance.classIsMissing())
        {
            Enumeration enumAtts=m_Instances.enumerateAttributes();
            int attIndex=;
            //循环处理没一个属性
            while (enumAtts.hasMoreElements()) 
            {
                Attribute attribute = (Attribute) enumAtts.nextElement();
                if(!instance.isMissing(attribute))
                {
                    //m_Distributons第一个下标记为当亲属性下标记,第二个下标为类别值
                    //统计样本实例对应类别属性值的分布
                    //调用Estimator的AddValue方法将新数据值加入到当前评估器中
                    m_Distributions[attIndex][(int)instance.classValue()].addValue(instance.value(attribute),
                            instance.weight());
                }
                attIndex++;
            }
            //统计类别分布
            m_ClassDistribution.addValue(instance.classValue(), instance.weight());
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception
    {
        //如果使用useSupervisedDiscretization选项,就对实例进行离散化
        if(m_UseDiscretization)
        {
            m_Disc.input(instance);
            instance=m_Disc.output();
        }
        //类别的概率P(Y)
        double probs[]=new double[m_NumClasses];
        //循环得到每个类别的概率
        for(int j=;j<m_NumClasses;j++)
        {
            probs[j]=m_ClassDistribution.getProbability(j);
        }
        Enumeration enumAtts=instance.enumerateAttributes();
        int attIndex=;
        //循环处理每个属性
        while(enumAtts.hasMoreElements())
        {
            Attribute attribute=(Attribute) enumAtts.nextElement();
            if(!instance.isMissing(attribute))
            {
                //temp为临时概率,max为当前最大概率
                double temp,max=;
                for (int j = ; j < m_NumClasses; j++)
                {
                    //计算每个类别的条件概率P(X|Y)
                    temp=Math.max(, Math.pow(m_Distributions[attIndex][j].getProbability(instance.value(attribute)), 
                            m_Instances.attribute(attIndex).weight()));
                    probs[j]*=temp;
                    //更新最大概率值
                    if(probs[j]>max)
                    {
                        max=probs[j];
                    }
                    if(Double.isNaN(probs[j]))
                    {
                        throw new Exception(
                                "Nan returned from estimator for atrribute "+
                                attribute.name()+":\n"+
                                m_Distributions[attIndex][j].toString());
                    }
                }
                if(max> && max<)
                {
                    //防止概率下溢的危险
                    for(int j=;j<m_NumClasses;j++)
                    {
                        probs[j]*=;
                    }
                }
            }
            attIndex++;
        }

        //概率规范化
        Utils.normalize(probs);
        return probs;
    }
           

继续阅读