天天看點

機器學習(10.3)--手寫數字識别的不同算法比較(3)--支援向量機(SVM)算法

在之前的文章中我并沒有寫SVM算法,主要原因在于這個雖然我知道SVM的基本原理,

但中間關鍵的最大化決策邊界的算法我寫不出來

隻能使用sklearn提供的方法來求得這個最大化決策邊界,不過還好,至少SVM的基本原理不難了解,

這篇文章,我将詳細說明SVM的基本原理,并用一個簡單的小例子(代碼段一)來測試,

同時代碼段二是用SVM來進行手寫數字識别,但這個算法非常的久,你可能會認為機器假死,

在沒有最後print運作計算都還在運作,有耐心的可以慢慢等待

關于使用的資料集,可參考

機器學習(10.1)--手寫數字識别的不同算法比較(1)--mnist資料集不同版本解析及平均灰階實踐

# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm

#下面兩行,解決matplotlib中無法顯示中文的問題
from pylab import *  
mpl.rcParams['font.sans-serif'] = ['SimHei']  


clf_1_points=np.array([[1,1],[1,0.5],[1.5,1],[2,1]])
clf_2_points=np.array([[3,3],[5,5],[5,3]])

plt.title('第一步,初始化定義了N個點,分為兩類,一類用紅色表示,一類用藍色表示')#顯示在整個圖上方的标題
plt.axis([0,6,0,6])#定義坐标系的範圍
plt.scatter(clf_1_points[:,0],clf_1_points[:,1],c='r')
plt.scatter(clf_2_points[:,0],clf_2_points[:,1],c='b')
plt.show()

plt.title('第二步,找一條直線,能夠将這兩類區分開來,同時到這兩類最近的點的距離相同\n例如圖中,兩條虛線都穿過率直線最近的點,他們到直線的距離相等')
plt.axis([0,6,0,6])#定義坐标系的範圍
plt.scatter(clf_1_points[:,0],clf_1_points[:,1],c='r')
plt.scatter(clf_2_points[:,0],clf_2_points[:,1],c='b')
plt.plot([2.5,2.5],[0,6]) 
plt.plot([2.0,2.0],[0,6],ls='-.')
plt.plot([3.0,3.0],[0,6],ls='-.')
plt.show()

plt.title('但是,這樣的直線,有無數條....\n又但是,雖然直線有無數條,但直線所對應的平行線(虛線)之間的距離是不同的,')
plt.axis([0,6,0,6])#定義坐标系的範圍
plt.scatter(clf_1_points[:,0],clf_1_points[:,1],c='r')
plt.scatter(clf_2_points[:,0],clf_2_points[:,1],c='b')
plt.plot([4.3,0],[0,4.3]) 
plt.plot([2.5,2.5],[0,6]) 
plt.show()

plt.title('SVM的目的,就是找到平行線(虛線)之間的距離最大的那一組,稱為:最大化決策邊界\n這樣做的好處就是當在預測一個新的點的類别時,\n如果這個正好落在兩虛線範圍内時,有更大空間做出選擇')
plt.axis([0,6,0,6])#定義坐标系的範圍
plt.scatter(clf_1_points[:,0],clf_1_points[:,1],c='r')
plt.scatter(clf_2_points[:,0],clf_2_points[:,1],c='b')
plt.plot([2.0,2.0],[0,6],ls='-.')
plt.plot([3.0,3.0],[0,6],ls='-.')
plt.show()

#調整為sklearn svm所需要的資料格式,points為所有的點,labels為每個點所對應的坐标
points=np.row_stack([clf_1_points,clf_2_points])
labels=np.zeros(len(clf_1_points)).astype(np.int).tolist()+np.ones(len(clf_2_points)).astype(np.int).tolist()

clf = svm.SVC(kernel='linear')
clf.fit(points, labels)

print("最大化決策邊界(兩條平行線)穿過的點:"+str(clf.support_vectors_))

#求得直線
w = clf.coef_[0]
a = -w[0]/w[1]
xx = np.linspace(0, 6)
yy = a*xx - (clf.intercept_[0])/w[1]

#通過最大化決策邊界 穿過的clf.support_vectors_兩個點,及上面的a,求得兩條虛線
b = clf.support_vectors_[0]
yy_down = a*xx + (b[1] - a*b[0])
b = clf.support_vectors_[-1]
yy_up = a*xx + (b[1] - a*b[0])


plt.title('最終結果,\n直線y=ax+b中的a=%2f,b=%2f'%(a,(clf.intercept_[0])/w[1]))
plt.axis([0,6,0,6])#定義坐标系的範圍
plt.scatter(clf_1_points[:,0],clf_1_points[:,1],c='r')
plt.scatter(clf_2_points[:,0],clf_2_points[:,1],c='b')
plt.plot(xx, yy, 'k-')
plt.plot(xx, yy_down, 'k--')
plt.plot(xx, yy_up, 'k--')
plt.show()
           

用SVM來進行手寫數字識别

# -*- coding:utf-8 -*-
import pickle  
import gzip  
import numpy as np  
from sklearn import svm 

with gzip.open(r'mnist.pkl.gz', 'rb')  as f:
    training_data, validation_data, test_data = pickle.load(f,encoding='bytes') 

#如果用全部資料集,實在是太慢了,隻取了5000條的訓練集和1000條的測試集,
training_data=(training_data[0][0:5000],training_data[1][0:5000])
test_data=(test_data[0][0:1000],test_data[1][0:1000])

clf = svm.SVC()
clf.fit(training_data[0], training_data[1])
predications = clf.predict(test_data[0])
num_correct = sum(np.where(predications == test_data[1], 1, 0))

print ('共有測試資料%s條,正确%s條' % ( len(test_data[0]),num_correct))