天天看點

以圖像分割為例淺談支援向量機(SVM)

1. 什麼是支援向量機?

  在機器學習中,分類問題是一種非常常見也非常重要的問題。常見的分類方法有決策樹、聚類方法、貝葉斯分類等等。舉一個常見的分類的例子。如下圖1所示,在平面直角坐标系中,有一些點,已知這些點可以分為兩類,現在讓你将它們分類。

以圖像分割為例淺談支援向量機(SVM)

(圖1)

顯然我們可以發現所有的點一類位于左下角,一類位于右上角。是以我們可以很自然将它們分為兩類,如圖2所示:紅色的點代表一類,藍色的點代表一類。

以圖像分割為例淺談支援向量機(SVM)

(圖2)

現在如果讓你用一條直線将這兩類點分開,這應該是一件非常容易的事情,比如如圖3所示的三條直線都可以辦到這點。

以圖像分割為例淺談支援向量機(SVM)

(圖3)

事實上,可以很容易發現,我們可以作無數條直線将這兩類點分開。這裡,我們不禁要問,是不是所有的直線分類的效果都一樣好呢?如果不是,那麼哪一條直線分類效果最好呢?評判的标準又是什麼?比如對于如圖4所示的兩條直線,\(line1\)和\(line2\),這兩條直線哪條分類效果更好呢?

以圖像分割為例淺談支援向量機(SVM)

(圖4)

直覺上可以發現,\(line1\)的分類效果要比\(line2\)更好的,這是因為\(line1\)幾乎位于這兩類點的中間,不偏向于任何一類點;而\(line2\)則偏向右上部分的點更多一些。如果這時又增加了一些點讓你将它們歸為這兩類,顯然\(line1\)要更“公正”一些,而\(line2\)則有可能将本來屬于右上類的點錯誤地歸為左下類。說到這裡,你可能會問,如何才能确定那個最佳分類的直線呢?其實這正是支援向量機(\(SVM,Support Vector Machine\))要解決的問題。

  更一般地情況下,如圖5所示,有時兩類點(圖5中紅色的點和藍色的點)是交錯分布的,“你中有我,我中有你”,根本不可能用一條直線分開,這個時候該怎麼辦呢?這也是支援向量機要解決的問題,而且是支援向量機的優勢所在。這類問題叫做非線性分類問題。

以圖像分割為例淺談支援向量機(SVM)

(圖5)

  說到這裡,你可能大概有些明白支援向量機是用來幹什麼的了。支援向量機的基本模型是定義在特征空間上的間隔最大的線性分類器。它是一種二分類模型。當采用了核技巧之後,支援向量機即可以用于非線性分類。不同類型的支援向量機解決不同的問題。

1.線性可分支援向量機:當訓練資料線性可分時,通過硬間隔最大化,學習一個線性可分支援向量機。

2. 線性支援向量機:當訓練資料近似可分時,通過軟間隔最大化,學習一個線性支援向量機。

3. 非線性支援向量機:當訓練資料線性不可分時,通過使用核技巧以及軟間隔最大化,學習一個非線性支援向量機。

  以上隻是對于支援向量機的最粗淺的說明,其實支援向量機内在的數學原理還是非常複雜的,其内容也十分豐富。我在學習的過程中參考了不少教材,比如《資料挖掘導論》、《神經網絡與機器學習》、《Python大戰機器學習》等。裡面對于支援向量機有非常詳細的說明,而且還從數學的角度推導了一遍。個人覺得好好研究一下原理以及數學推導對于深刻了解支援向量機還是非常有幫助的。鑒于我這裡隻是介紹,而非嚴格地教程,是以公式就不羅列了,感興趣的請自行閱讀相關文獻與書籍。

2. 如何了解支援向量機?

  如果不從數學公式的角度出發,在不涉及公式細節的情況下,如何直覺了解支援向量機呢?雖然這并非易事(因為支援向量機的複雜性),但是還是可以辦到的。我在查閱資料的過程中,看到了知乎上的一個問題,裡面有幾個答案我覺得非常棒,可以讓你在不了解數學公式的情況下,對于支援向量機有一個直覺的了解。位址如下:

支援向量機(SVM)是什麼意思?

。這裡我仍然以兩類點的分類問題為例來談談我自己的了解。以圖1中的兩類點為例,前面我們已經說過了,存在無窮多條直線可以将這兩類點分開。現在我們的目标是在一定的準則下,找出劃分最好的那一條。從直覺的了解來看,這條最佳直線應該滿足“公正性”:即不偏向任何一類點,或者說處于中間位置。現在假設我們已經找到了一條分割直線\(l\),每一個樣本點都到這條直線存在一個距離。設直線\(l\)的方程為:\(wx + b = 0\),共有\(n\)個點,\(n\)個點的坐标為\((x_1,y_1),(x_2,y_2),\cdots,(x_n,y_n)\),\(n\)個點到直線\(l\)的距離分别為\(d_1,d_2,\cdots,d_n\),現在我們需要找\(d_1,d_2,\cdots,d_n\)中的最小值:\(d_{min} = min\{d_1,d_2,\cdots,d_n\}\),顯然我們希望\(d_{min}\)越大越好,\(d_{min}\)越大,說明它距離兩類的距離都較遠。于是問題轉化為在所有可行的直線劃分中,找到 使得\(d_{min}\)最大的那條即是最佳劃分直線。對于線性可分的情況而言,我們可以證明,這樣的最佳直線總是存在的。我們稱找到的最佳劃分兩類的直線為:最大幾何間隔分離超平面(對于二維點而言是直線,三維空間中則是平面,更高維則是超平面了,這裡統稱為超平面)。

什麼是支援向量?

  支援向量機(\(SVM\))之是以稱之為支援向量機,是因為有一個叫作支援向量(\(Support Vector\))的東西。那麼什麼叫作支援向量呢?假設我們現在已經找到了最大幾何間隔分離超平面,容易了解,我們可以找到許多條與這條直線平行的直線,在所有平行的直線中,存在兩條直線,它們恰好可以劃分兩類點,所謂恰好是指,如果再平移哪怕一點點,就會不能正确劃分兩類點,這兩條臨界直線(超平面)被稱之為間隔邊界。對于線性可分的情況而言,我們可以證明,在樣本點中總會有一些樣本點落在間隔邊界上(但是對于線性不可分的情況,則未必如此),落在間隔邊界上的這些樣本點就被我們稱為支援向量。之是以被稱之為支援向量呢,是因為我們确定的最大幾何間隔分離超平面隻與這些支援向量有關,與其他的樣本點無關,也就是說哪怕你去掉再多非支援向量的點,最大幾何間隔分離超平面也一樣不變。這也就是支援向量機名字的來源。

支援向量機如何處理線性不可分的情況?

  這個問題其實涉及到\(SVM\)的核心了。在之前我們多次提到了一個詞:核技巧。那麼什麼是核技巧呢?首先,我們需要明确輸入空間與特征空間這兩個概念。所謂輸入空間就是我們定義樣本點的空間,由于問題線性不可分,是以我們無法用一個超平面将兩類點分開,但是我們總可以找到一個合适的超曲面将兩類點正确劃分。問題的關鍵就是找到這個超曲面。直接尋找顯然是很困難的,是以我們聰明的數學家就定義了一個映射,簡單來說就是從低維到高維的映射,研究發現,如果映射定義地恰當,則原來在低維線性不可分的問題,到了高維居然就線性可分了!這真的是一個讓人驚喜的發現。是以我們隻要在高維按照之前線性可分的情況去找最大幾何間隔分離超平面,找到之後,再還原到低維就可以了。理論上已經證明,在低維線性不可分的情況下,我們總可以找到合适的從低維到高維的映射,使得在高維線性可分。于是問題的關鍵就是找這個從低維到高維的映射了,這個其實就是核函數(核技巧)要幹的事情了。具體的定義較為複雜,這裡不展開了。在給定核函數的情況下,我們可以利用求解線性分類問題的方法來求解非線性分類問題的支援向量機,學習是隐式地在特征空間(也就是映射之後的高維空間)進行的,這被稱之為核技巧。在實際應用中,往往直接依賴經驗選擇核函數,然後再驗證其是有效的即可。常用的核函數有:多項式核函數、高斯核函數、sigmoid核函數等。

3. 支援向量機的實際應用舉例(附matlab代碼與Python代碼)

1. 将兩類點分類(二維平面)

  作為第一個例子,我們首先解決開頭提到的那個平面上兩類點的分類問題。我們找出最大幾何間隔分離超平面與支援向量,然後驗證該最佳超平面能否對新加入的點進行準确分類。這裡我們分别使用Matlab與Pyhton來實作這個例子。Matlab中的\(svmtrain\)、\(svmclassify\)函數以及Python sklearn(一個機器學習的庫)均對SVM有很好的支援。如果想要詳細了解二者的用法,對于Matlab可以直接檢視其幫助手冊,對于Pyhton則可以參考相關機器學習的書籍或者直接去看sklearn的網站學習。

Matlab 對兩類點分類的代碼:

% 使用SVM(支援向量機)分割兩類點并畫出圖形
XY1 = 2 + rand(100,2); % 随機産生100行2列在2-3之間的點
XY2 = 3+ rand(100,2);% 随機産生100行2列在3-4之間的點
XY = [XY1;XY2]; % 合并兩點
Classify =[zeros(100,1);ones(100,1)]; % 第一類點用0表示,第二類點用1表示
Sample = 2+ 2*rand(100,2); % 測試點
%figure(1);
%plot(XY1(:,1),XY1(:,2),'r*'); % 第一類點用紅色表示
%hold on;
%plot(XY2(:,1),XY2(:,2),'b*'); % 第二類點用藍色表示
% 訓練SVM
SVM = svmtrain(XY,Classify,'showplot',true);
% 給測試點分類,并作出最大間隔超平面(一條直線)
svmclassify(SVM,Sample,'showplot',true);            

得到結果如圖6所示:

以圖像分割為例淺談支援向量機(SVM)

(圖6)

圖6中的直線即是所求的最大幾何間隔分離超平面,畫黑圈的點為支援向量,而且可以看出其對新增加的點劃分得很好,這說明了SVM最大幾何間隔分離超平面分類的有效性。

再來看Python的代碼:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time   : 2017/7/22 10:45
# @Author : Lyrichu
# @Email  : [email protected]
# @File   : svm_split_points.py
'''
@Description:使用svm對兩類點進行分類(線性可分)
'''
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import LinearSVC # 導入SVM 線性分類器
XY1 = 2 + np.random.rand(100,2) # 100行2列在2到3之間的資料點
XY2 = 4 + np.random.rand(100,2) # 100行2列在4到5之間的資料點
XY = np.concatenate((XY1,XY2),axis=0)
test_data = 2 + 3*np.random.rand(100,2) # 測試資料,2-5之間
label = np.append(np.zeros(100),np.ones(100)) # XY1 用0标志,XY2用1标志
svm = LinearSVC()
svm.fit(XY,label)
predict_test =svm.predict(test_data) # 對測試資料進行預測
coef = svm.coef_ # 系數(w向量)
intercept = svm.intercept_ # 截距(b)
# print("coef:",coef)
# print("intercept:",intercept)
# print("predict_test:",predict_test)
sort1_index = predict_test == 0. # 測試資料屬于第一類的序号(bool 數組)
sort2_index = predict_test == 1. # 測試資料屬于第二類的序号(bool 數組)
test_sort1 = test_data[sort1_index,:] # 測試資料屬于第一類的點
test_sort2 = test_data[sort2_index,:] # 測試資料屬于第二類的點
# 最大間隔超平面的方程為:Wx + b = 0
# 畫圖
plt.plot(XY1[:,0],XY1[:,1],'r*',label='train data 1')
plt.plot(XY2[:,0],XY2[:,1],'b*',label='train data 2')
line_x = np.arange(2,5,0.01) # 直線x坐标
line_y = (coef[0,0]*line_x + intercept[0])/(-coef[0,1]) # 直線y坐标
# 畫出直線
plt.plot(line_x,line_y,'-')
# 畫出預測點
plt.plot(test_sort1[:,0],test_sort1[:,1],'r+',label='test data 1')
plt.plot(test_sort2[:,0],test_sort2[:,1],'b+',label='test data 2')
plt.legend(loc = 'best')
plt.show()           

結果如下圖7所示:

以圖像分割為例淺談支援向量機(SVM)

(圖7)

其中那條直線即是作出的最大幾何間隔分離超平面,train data 1 和 train data 2為第一、二類訓練資料,test data 1和 test data 2 為第一、二類測試資料。可以看出 SVM 分類的效果很好。

2. 将圖像中的某個物體從背景中分割出來(這裡以分割在湖中遊泳的鴨子為例)

  如圖8所示,湖面上有一隻鴨子,現在我們希望将鴨子從湖水(背景)中分割出來,該怎麼做呢?

以圖像分割為例淺談支援向量機(SVM)

如果你手中有類似PS這樣的軟體,完成這個任務應該并不困難,不就是摳圖麼!!!但是,摳圖需要我們自己手動找分割線啊,多麻煩呢,能不能讓計算機自動完成這個工作呢?當然是可以的,利用上面說的SVM就可以辦到。那麼該怎麼做呢?我們知道,彩色圖檔本質上是由一個一個的像素點組成的,每一個像素點由RGB三色組成,或者說本質上彩色圖像就是三維數組,而灰階圖像則是二維數組。如果我們将湖水和鴨子看做兩類物體,那麼現在的任務則是從整個圖像中将這兩類分割出來。顯然鴨子與湖水的界限并不是一條單純的直線,甚至有些地方是交雜在一起的,是以本質上這是一個非線性可分的問題。從圖中可以看出,鴨子的顔色偏黑色和灰色,摻雜有少量白色以及黃色(鴨腳),而湖水則是淺綠色的。是以我們可以以顔色為标準對二者進行分類,即以RGB為分類标準。為了使用SVM,首先我們需要選取訓練樣本,這裡就是找出典型的屬于鴨子的像素點RGB值(為一個長度為3的向量),和屬于湖水的RGB值。關于如何确定圖像上某一點的RGB值,有很多辦法,這裡我推薦使用一個名為Colorpix的小軟體,這個軟體隻有幾百kb,一個exe執行檔案,可以找出螢幕上任何一點的像素屬性,用起來很友善,如果要用,請大家自行搜尋。這裡我對于湖水和鴨子分别選取了10個像素點,這樣我就得到了一個20行3列的樣本資料(每一行是一個樣本,共有20個樣本)。将湖水的像素點标記為0,鴨子的像素點标記為1,這樣我們就可以得到長度為20的、前10個元素為0,後10個元素為1的向量。由于圖像原始資料為三維矩陣,比如設其次元為\((m,n,k)\),我們首先需要将其轉化為2維,即轉化為\((mn,k)\)的矩陣,然後使用線性不可分的SVM訓練樣本資料,接着使用訓練好的SVM對\((mn,k)\)矩陣進行歸類,我們得到一個長為\(mn\)的資料取0或者1的一維數組\(predict\),為0的部分就是代表對應的像素點判定為湖水了。接着将\(predict\)數組在行的方向上擴充為3列,即變為\((predict,predict,predict)\),擴充之後的矩陣次元為\((mn,k)\),再将其變回三維矩陣,即\((m,n,k)\)的矩陣。該矩陣與原始圖像三維矩陣對應,該矩陣資料點為\((0,0,0)\)的部分即判定為湖水,我們将圖像上該像素點的RGB值變為\((255,255,255)\)(白色),于是我們就可以得到去掉湖水(變為白色背景)的鴨子了。

  以上就是使用SVM将鴨子從湖水中分割出來的步驟了。下面給出代碼:

1. Matlab 代碼

% 使用SVM将鴨子從湖面分割
% 導入圖像檔案引導對話框
[filename,pathname,flag] = uigetfile('*.jpg','請導入圖像檔案');
Duck = imread([pathname,filename]);
%使用ColorPix軟體從圖上選取幾個湖面的代表性點的RGB的值
LakeTrainData = [147,168,125;151 173 124;143 159 112;150 168 126;...
    146 165 120;145 161 116;150 171 130;146 112 137;149 169 120;144 160 111];
% 從圖中選取幾個有代表性的鴨子點的RGB值
DuckTrainData = [81 76 82;212 202 193;177 159 157;129 112 105;167 147 136;...
    237 207 145;226 207 192;95 81 68;198 216 218;197 180 128];
% 屬于湖的點為0,鴨子的點為1
group = [zeros(size(LakeTrainData,1),1);ones(size(DuckTrainData,1),1)];
% 訓練得到支援向量分類機
LakeDuckSVM = svmtrain([LakeTrainData;DuckTrainData],group,'kernel_function','polynomial',...
    'polyorder',2);
[m,n,k] = size(Duck); % 圖像三維矩陣
% 将Duck轉化為雙精度的m*n行,3列的矩陣
Duck1 = double(reshape(Duck,m*n,k));
% 根據訓練得到的支援向量機對整個圖像像素點進行分類
IndDuck = svmclassify(LakeDuckSVM,Duck1);
% 屬于湖的點的邏輯數組
IndLake = ~IndDuck;
result = reshape([IndLake,IndLake,IndLake],[m,n,k]); % 與圖檔的維數對應
Duck2 = Duck;
Duck2(result)= 255; % 湖面的點變為白色
figure;
imshow(Duck2); % 顯示分割之後的圖像           

結果如圖8所示:

以圖像分割為例淺談支援向量機(SVM)

(圖8)

可以基本看到鴨子的輪廓了,但是鴨子身體中有很多小點被扣去了(屬于誤判為湖水),這種情況可以改變一些選取的像素點,或者增加一些樣本,可以優化分割的效果。

再來看Python的實作吧。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time   : 2017/7/22 13:58
# @Author : Lyrichu
# @Email  : [email protected]
# @File   : svm_split_picture.py
'''
@Description:SVM 将在湖中的一隻鴨子與湖水分割出來
'''
from PIL import Image
import numpy as np
from sklearn.svm import SVC # 非線性 分類 SVM
pic = 'duck.jpg' # 鴨子圖檔
img = Image.open(pic)
img.show() # 顯示原始圖像
img_arr = np.asarray(img,np.float64)
# 選取湖面上的關鍵點RGB值(10個)
lake_RGB = np.array(
    [[147,168,125],[151,173,124],[143,159,112],[150,168,126],[146,165,120],
     [145,161,116],[150,171,130],[146,112,137],[149,169,120],[144,160,111]]
)
# 選取鴨子上的關鍵點RGB值(10個)
duck_RGB = np.array(
    [[81,76,82],[212,202,193],[177,159,157],[129,112,105],[167,147,136],
     [237,207,145],[226,207,192],[95,81,68],[198,216,218],[197,180,128]]
)
RGB_arr = np.concatenate((lake_RGB,duck_RGB),axis=0) # 按列拼接
# lake 用 0标記,duck用1标記
label = np.append(np.zeros(lake_RGB.shape[0]),np.ones(duck_RGB.shape[0]))
# 原本 img_arr 形狀為(m,n,k),現在轉化為(m*n,k)
img_reshape = img_arr.reshape([img_arr.shape[0]*img_arr.shape[1],img_arr.shape[2]])
svc = SVC(kernel='poly',degree=3) # 使用多項式核,次數為3
svc.fit(RGB_arr,label) # SVM 訓練樣本
predict = svc.predict(img_reshape) # 預測測試點
lake_bool = predict == 0. # 為湖面的序号(bool)
lake_bool = lake_bool[:,np.newaxis] # 增加一列(一維變二維)
lake_bool_3col = np.concatenate((lake_bool,lake_bool,lake_bool),axis=1) # 變為三列
lake_bool_3d = lake_bool_3col.reshape((img_arr.shape[0],img_arr.shape[1],img_arr.shape[2])) # 變回三維數組(邏輯數組)
img_arr[lake_bool_3d] = 255. # 将湖面像素點變為白色
img_split = Image.fromarray(img_arr.astype('uint8')) # 數組轉image
img_split.show() # 顯示分割之後的圖像
img_split.save('split_duck.jpg') # 儲存           

結果如圖9所示:

以圖像分割為例淺談支援向量機(SVM)

(圖9)

可以看出,圖9的效果要比圖8好很多,基本已經将湖水全部去除了,隻有少數點沒有去除,如果增加一些訓練樣本,訓練的效果應該會更好,大家有興趣的可以自己嘗試一下。不過我很奇怪的是,Matlab與pyhton我選取的像素點是一模一樣的,SVM訓練設定參數也是一樣的,為什麼python的效果要明顯好于Matlab呢?我沒有閱讀二者SVM的源碼,不好下結論,姑且認為是Python大法好吧!!!哈哈哈......

  以上就是主要要講的内容了。其實SVM在最近幾年神經網絡大火之前還是非常受歡迎的,不過現在做複雜分類(比如圖像分類,語音識别等)好像更傾向于神經網絡了,SVM的一個重大缺點就是其對于處理大規模資料不是很适合,因為其主流的算法複雜度都是\(O(n^2)\)的,不過其在高維資料以及規模适中的情況下做分類效果還是很不錯的。以後有機會再來和大家探讨深度學習以及神經網絡吧,目前正入坑中。。。

Reference

  1. 《資料挖掘概念與技術》
  2. 《神經網絡與機器學習》
  3. 《Python大戰機器學習》
  4. 《Matlab在數學模組化中的應用》

    特别感謝《Matlab在數學模組化中的應用》,圖像分割的那個例子Matlab代碼改編于此,Python代碼也是基于此書改寫的。

熱愛程式設計,熱愛機器學習!

github:http://www.github.com/Lyrichu

github blog:http://Lyrichu.github.io

個人部落格站點:http://www.movieb2b.com(不再維護)