天天看點

MATLAB實作LIBSVM中的c和g的參數尋優

引言:LIBSVM是台灣大學林智仁(Lin Chih-Jen)教授等開發設計的一個簡單、易于使用和快速有效的SVM模式識别與回歸的軟體包,他不但提供了編譯好的可在Windows系列系統的執行檔案,還提供了源代碼,友善改進、修改以及在其它作業系統上應用;該軟體對SVM所涉及的參數調節相對比較少,提供了很多的預設參數,利用這些預設參數可以解決很多問題;并提供了互動檢驗(Cross Validation)的功能。該軟體可以解決C-SVM、ν-SVM、ε-SVR和ν-SVR等問題,包括基于一對一算法的多類模式識别問題。

我們在進行科學研究的時候會經常使用SVM對資料進行分類,MATLAB自帶的SVM函數調參麻煩,且隻支援分類問題,不支援回歸問題。是以,林教授開發的功能更為強大的LIBSVM就是我們的不二選擇。

LIBSVM支援MATLAB、Python、C等編譯語言,今天我将講解的是在MATLAB環境下調用LIBSVM。

其中:LIBSVM工具包下載下傳:https://www.csie.ntu.edu.tw/~cjlin/libsvm/,具體安裝過程這裡就不詳細介紹了,大家可以參考其他部落格。

一:LIBSVM的使用

模型訓練:model = svmtrain(label,data,'libsvm_options');

模型預測:[predicted_label,accuary] = svmpredict(label_test,data_test,model,'libsvm_options')

其中libsvm_options為可選參數,其具體内容如下:

-s 設定svm類型:

 0 – C-SVC

 1 – v-SVC

 2 – one-class-SVM

 3 – ε-SVR

 4 – n – SVR

-t 設定核函數類型, 預設值為2

0 — 線性核: μ‘∗ν

1 — 多項式核:    (γ∗μ‘∗ν+coef0)degree

2 — RBF核: exp(–γ∗∥μ−ν∥2)

3 — sigmoid 核: tanh(γ∗μ‘∗ν+coef0)

-d degree: 核函數中的degree設定(針對多項式核函數)(預設3);

-g r(gama): 核函數中的gamma函數設定(針對多項式/rbf/sigmoid核函數)(預設1/ k);

-r coef0: 核函數中的coef0設定(針對多項式/sigmoid核函數)((預設0);

-c cost: 設定C-SVC, e -SVR和v-SVR的參數(損失函數)(預設1);

-n nu: 設定v-SVC, 一類SVM和v- SVR的參數(預設0.5);

-p p: 設定e -SVR 中損失函數p的值(預設0.1);

-m cachesize: 設定cache記憶體大小, 以MB為機關(預設40);

-e eps: 設定允許的終止判據(預設0.001);

-h shrinking: 是否使用啟發式, 0或1(預設1);

-wi weight: 設定第幾類的參數C為weight*C (C-SVC中的C) (預設1);

-v n: n-fold互動檢驗模式, n為fold的個數, 必須大于等于2;

-b 機率估計: 是否計算SVC或SVR的機率估計, 可選值0或1, 預設0;
           

例:

分類問題:

model = svmtrain(label_train,data_train,'-s 0 -t 2 -c 0.1 -g 0.1');

[predicted_label,accuray] = svmpredict(label_test,data_test,model)

回歸問題:

model = svmtrain(label_train,data_train,'-s 3 -t 2-c 0.1 -g 0.1 -p 0.01')

predicted_label = svmpredict(label_test,data_test,model)

二:c、g參數尋優問題

針對SVM中的參數優化,Python環境下的LIBSVM中有尋優函數grid.py幫助大家尋找最優的c和g:

MATLAB實作LIBSVM中的c和g的參數尋優

但是MATLAB環境下的LIBSVM卻沒有這個功能,是以今天這裡就給大家分享在MATLAB環境下實作LIBSVM參數c和g的自動尋優:

function [best_c,best_g,best_acc] = SvmSearchParas(data,label,c_max,c_min,c_step,g_max,g_min,g_step,v)
%--------------------------------------------------------------------------
%The function looks for the SVM's most important parameters c and g
%The Author:等等登登-Ande
%The Email:[email protected]
%The Blog:qq_35166974
%%
%Initialization parameter
if nargin < 9
    v = 10;
end
if nargin < 8
    v = 10;
    g_step = 1;
end
if nargin < 7
    v = 10;
    g_step = 1;
    c_step = 1;
end
if nargin < 6
    v = 10;
    g_step = 1;
    c_step = 1;
    g_min = -5;
end
if nargin < 5
    v = 10;
    g_step = 1;
    c_step = 1;
    g_min = -5;
    g_max = 5;
end
if nargin < 4
    v = 10;
    g_step = 1;
    c_step = 1;
    g_min = -5;
    g_max = 5;
    c_min = -5;
end
if nargin < 3
    v = 10;
    g_step = 1;
    c_step = 1;
    g_min = -5;
    g_max = 5;
    c_min = -5;
    c_max = 5;
end
if nargin < 2
    warning('You did not enter enough parameters!');
end
%%
%Parameter optimization
[mesh1,mesh2] = meshgrid(c_min:c_step:c_max,g_min:g_step:g_max);
[raw,col] = size(mesh1);
acc = zeros(raw,col);
for i=1:raw
    for j=1:col
        cg_paras = ['-v ',num2str(v),'-c ',num2str(2.^mesh1(i,j)),' ','-g ',num2str(2.^mesh2(i,j))];
        acc(i,j) = libsvmtrain(double(label),double(data),cg_paras);
    end
end
best_acc = max(max(acc));
[label_i,label_j] = find(acc==best_acc);
best_c = 2.^mesh1(label_i,label_j);
best_g = 2.^mesh2(label_i,label_j);
figure
mesh(mesh1,mesh2,acc);
xlabel('log2c');
ylabel('log2g');
zlabel('Accuracy')
           

繼續閱讀