天天看點

使用scikit-learn的svm進行分類(代碼分析)

基于SciPy的衆多分支版本中,最有名,也是專門面向機器學習的就是Scikit-learn。Scikit-learn項目最早由資料科學家 David Cournapeau 在 2007 年發起,需要NumPy和SciPy等其他包的支援,是Python語言中專門針對機器學習應用而發展起來的一款開源架構。Scikit-learn從來不做除機器學習領域之外的其他擴充,也從來不采用未經廣泛驗證的算法。

Scikit-learn的基本功能主要被分為六大部分:分類,回歸,聚類,資料降維,模型選擇和資料預處理。

我們今天在這裡隻說部分内容

分類是指識别給定對象的所屬類别,屬于監督學習的範疇,最常見的應用場景包括垃圾郵件檢測和圖像識别等。目前Scikit-learn已經實作的算法包括:支援向量機(SVM),最近鄰,邏輯回歸,随機森林,決策樹以及多層感覺器(MLP)神經網絡等等。

需要指出的是,由于Scikit-learn本身不支援深度學習,也不支援GPU加速,是以這裡對于MLP的實作并不适合于處理大規模問題。有相關需求的讀者可以檢視同樣對Python有良好支援的Keras和Theano等架構

資料降維是指使用主成分分析(PCA)、非負矩陣分解(NMF)或特征選擇等降維技術來減少要考慮的随機變量的個數,其主要應用場景包括可視化處理和效率提升。

模型選擇是指對于給定參數和模型的比較、驗證和選擇,其主要目的是通過參數調整來提升精度。目前Scikit-learn實作的子產品包括:格點搜尋,交叉驗證和各種針對預測誤差評估的度量函數。

資料預處理是指資料的特征提取和歸一化,是機器學習過程中的第一個也是最重要的一個環節。這裡歸一化是指将輸入資料轉換為具有零均值和機關權方差的新變量,但因為大多數時候都做不到精确等于零,是以會設定一個可接受的範圍,一般都要求落在0-1之間。而特征提取是指将文本或圖像資料轉換為可用于機器學習的數字變量。

1. 安裝:

之前已經搭建了基于anaconda虛拟環境的TensorFlow平台,安裝了python 3.6,NumPy,SciPy。

在虛拟環境下運作pip install -U scikit-learn

使用scikit-learn的svm進行分類(代碼分析)

2.跑樣例代碼:

https://scikit-learn.org/stable/auto_examples/index.html#general-examples 都在這個連結裡

比如Recognizing hand-written digits:

列印結果時顯示上面的注釋:

"""
================================
Recognizing hand-written digits
================================

An example showing how the scikit-learn can be used to recognize images of
hand-written digits.

This example is commented in the
:ref:`tutorial section of the user manual <introduction>`.

"""
print(__doc__)
           

導入所需子產品:

import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics
           

讀入資料

digits = datasets.load_digits()
           

資料是長這樣的,總共有1797張圖像,每張圖像8*8,還有對應的标簽:

使用scikit-learn的svm進行分類(代碼分析)

把标簽和資料程式設計一個list,并顯示前四個:

使用scikit-learn的svm進行分類(代碼分析)
使用scikit-learn的svm進行分類(代碼分析)

然後整理資料

# To apply a classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images) #1797
data = digits.images.reshape((n_samples, -1))#(1797,64)
           

建立分類器,使用前一部分資料進行訓練分類:

# Create a classifier: a support vector classifier
classifier = svm.SVC(gamma=0.001)

# We learn the digits on the first half of the digits
classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])
           

關于svm.SVC參數詳解:https://blog.csdn.net/github_39261590/article/details/75009069   。 gamma: float參數 預設為auto,核函數系數,隻對‘rbf’,‘poly’,‘sigmod’有效。如果gamma為auto,代表其值為樣本特征數的倒數,即1/n_features.

下面對後半部分資料進行預測:

# Now predict the value of the digit on the second half:
expected = digits.target[n_samples // 2:]
predicted = classifier.predict(data[n_samples // 2:])
           

列印和顯示:

print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(expected, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prediction: %i' % prediction)
           

顯示結果:

使用scikit-learn的svm進行分類(代碼分析)

最後的最後:

https://scikit-learn.org/stable/auto_examples/index.html#general-examples 

這個其他的代碼也可以一起看看哈,都是比較基礎的~

參考連結:

https://www.leiphone.com/news/201701/ZJMTak4Y8ch3Nwd0.html

繼續閱讀