# 天天看點 — page header left over from the web page this script was copied from.

# numpy實作樸素貝葉斯 — Gaussian naive Bayes implemented with numpy.

import numpy as np 

import matplotlib.pyplot as plt  

# Toy experiment: Gaussian naive Bayes on random integer points in [-300, 300)^4.
# Ground-truth labels come from the hyperplane x0 + 2*x1 + 3*x2 + 4*x3 - 1 = 0:
# +1 on the positive side, -1 otherwise. The script trains on w11 points,
# evaluates on w points and prints the test accuracy.

w = 250  # number of test samples
w11 = 2000  # number of training samples

# Coefficients and offset of the labelling hyperplane.
LABEL_WEIGHTS = np.array([1.0, 2.0, 3.0, 4.0])
LABEL_BIAS = -1.0


def true_labels(x, weights=LABEL_WEIGHTS, bias=LABEL_BIAS):
    """Return a +1/-1 label for every row of *x* using the hyperplane rule.

    A row is labelled +1 when ``row @ weights + bias > 0`` and -1 otherwise
    (points exactly on the hyperplane are -1, as in the original loop).
    """
    return np.where(np.asarray(x, dtype=float) @ weights + bias > 0, 1.0, -1.0)


def fit_gaussian(samples):
    """Return the per-feature ``(mean, std)`` arrays of *samples* (one row per sample).

    ``std`` is the population standard deviation (numpy's default ddof=0),
    matching the original per-column ``.std()`` calls.

    Raises:
        ValueError: if *samples* has no rows, because no Gaussian can be fitted.
    """
    samples = np.asarray(samples, dtype=float)
    if samples.ndim != 2 or samples.shape[0] == 0:
        raise ValueError("cannot fit a Gaussian to an empty class")
    return samples.mean(axis=0), samples.std(axis=0)


def log_joint(x, prior, mean, std):
    """Return ``log P(class) + sum_j log N(x_j; mean_j, std_j)`` for each row of *x*.

    This is the log of the original product
    ``prior * prod_j exp(-(x_j - mean_j)**2 / (2 std_j**2)) / std_j``.
    Working in log space avoids underflow of the product. The
    ``-0.5*log(2*pi)`` term per feature is omitted. It is the same for every
    class, so it never changes which class wins.
    """
    x = np.atleast_2d(np.asarray(x, dtype=float))
    z = (x - mean) / std
    return np.log(prior) - np.sum(0.5 * z * z + np.log(std), axis=1)


def predict(x, pos, neg):
    """Predict +1/-1 for every row of *x*.

    *pos* and *neg* are ``(prior, mean, std)`` tuples for the +1 and -1
    classes. Ties are resolved as -1, like the original ``p1 > p2`` test.
    """
    return np.where(log_joint(x, *pos) > log_joint(x, *neg), 1.0, -1.0)


# Draw training data first and test data second (same order as before).
train = np.random.randint(-300, 300, (w11, 4)).astype(float)
train_lable = true_labels(train)

test = np.random.randint(-300, 300, (w, 4)).astype(float)
test_lable = true_labels(test)

# Split the training set by class. Previously every point was appended to
# both w1 and w2 because the `else` branch was missing.
w1 = train[train_lable == 1]
w2 = train[train_lable == -1]

# Class priors are fractions of the *training* set. Previously they were
# divided by the test-set size, giving values larger than 1.
py1 = len(w1) / w11
py2 = len(w2) / w11

mean1, std1 = fit_gaussian(w1)
mean2, std2 = fit_gaussian(w2)

# Count correct predictions over all w test samples (both classes).
pred = predict(test, (py1, mean1, std1), (py2, mean2, std2))
acc = int(np.sum(pred == test_lable))

print(acc / w)

# 繼續閱讀 ("continue reading") — trailing link text left over from the web page.