天天看點

【計算機視覺(CV)】基于k-means實作鸢尾花聚類【計算機視覺(CV)】基于k-means實作鸢尾花聚類前言一、鸢尾花資料集描述二、資料集預處理三、模型訓練四、可視化展示六、模型預測總結

【計算機視覺(CV)】基于k-means實作鸢尾花聚類

作者簡介:在校大學生一枚,華為雲享專家,阿裡雲專家部落客,騰雲先鋒(TDP)成員,雲曦智劃項目總負責人,全國高等學校計算機教學與産業實踐資源建設專家委員會(TIPCC)志願者,以及程式設計愛好者,期待和大家一起學習,一起進步~

.

部落格首頁:ぃ靈彧が的學習日志

.

本文專欄:人工智能

.

專欄寄語:若你決定燦爛,山無遮,海無攔

.

【計算機視覺(CV)】基于k-means實作鸢尾花聚類【計算機視覺(CV)】基于k-means實作鸢尾花聚類前言一、鸢尾花資料集描述二、資料集預處理三、模型訓練四、可視化展示六、模型預測總結

(文章目錄)

前言

(一)、任務描述

對于給定的樣本集,按照樣本之間的距離大小,将樣本集劃分為K個簇,讓簇内的點盡量緊密的連在一起,而讓簇間的距離盡量的大

【計算機視覺(CV)】基于k-means實作鸢尾花聚類【計算機視覺(CV)】基于k-means實作鸢尾花聚類前言一、鸢尾花資料集描述二、資料集預處理三、模型訓練四、可視化展示六、模型預測總結

(二)、環境配置

本實踐代碼運作的環境配置如下:Python版本為3.7,PaddlePaddle版本為2.0.0,操作平台為AI Studio。大部分深度學習項目都要經過以下幾個過程:資料準備、模型配置、模型訓練、模型評估。

import paddle
import numpy as np
import matplotlib.pyplot as plt
print(paddle.__version__)

# cpu/gpu環境選擇,在 paddle.set_device() 輸入對應運作裝置。
# device = paddle.set_device('gpu')
           

一、鸢尾花資料集描述

1、包含3種類型資料集,共150條資料 ;2、包含4項特征:花萼長度、花萼寬度、花瓣長度、花瓣寬度

二、資料集預處理

本案例主要分以下幾個步驟進行資料預處理:

(1)解壓原始資料集

(2)按照比例劃分訓練集與驗證集

(3)亂序,生成資料清單

(4)定義資料讀取器,轉換圖檔

(一)、導入相關包

首先我們引入本案例需要的所有子產品

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans 
from sklearn import datasets 
           

(二)、加載資料集

# 直接從sklearn中擷取資料集
iris = datasets.load_iris()
X = iris.data[:, :4]    # 表示我們取特征空間中的4個次元
print(X.shape)
           

(三)、繪制二維資料分布圖

每個樣本使用兩個特征,繪制其二維資料分布圖

# 取前兩個次元(萼片長度、萼片寬度),繪制資料分布圖
plt.scatter(X[:, 0], X[:, 1], c="red", marker='o', label='see')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend(loc=2)
plt.show() 

           

輸出結果如下圖1所示:

【計算機視覺(CV)】基于k-means實作鸢尾花聚類【計算機視覺(CV)】基于k-means實作鸢尾花聚類前言一、鸢尾花資料集描述二、資料集預處理三、模型訓練四、可視化展示六、模型預測總結

(四)、執行個體化K-means類,并且定義訓練函數

def Model(n_clusters):
    estimator = KMeans(n_clusters=n_clusters)# 構造聚類器
    return estimator

def train(estimator):
    estimator.fit(X)  # 聚類
           

三、模型訓練

# 初始化執行個體,并開啟訓練拟合
estimator=Model(3)     
train(estimator)     
           

四、可視化展示

label_pred = estimator.labels_  # 擷取聚類标簽
# 繪制k-means結果
x0 = X[label_pred == 0]
x1 = X[label_pred == 1]
x2 = X[label_pred == 2]
plt.scatter(x0[:, 0], x0[:, 1], c="red", marker='o', label='label0')
plt.scatter(x1[:, 0], x1[:, 1], c="green", marker='*', label='label1')
plt.scatter(x2[:, 0], x2[:, 1], c="blue", marker='+', label='label2')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend(loc=2)
plt.show() 
           

輸出結果如下圖2所示:

【計算機視覺(CV)】基于k-means實作鸢尾花聚類【計算機視覺(CV)】基于k-means實作鸢尾花聚類前言一、鸢尾花資料集描述二、資料集預處理三、模型訓練四、可視化展示六、模型預測總結
# 法一:直接手寫實作
# 歐氏距離計算
def distEclud(x,y):
    return np.sqrt(np.sum((x-y)**2))  # 計算歐氏距離
 
# 為給定資料集建構一個包含K個随機質心centroids的集合
def randCent(dataSet,k):
    m,n = dataSet.shape #m=150,n=4
    centroids = np.zeros((k,n)) #4*4
    for i in range(k): # 執行四次
        index = int(np.random.uniform(0,m)) # 産生0到150的随機數(在資料集中随機挑一個向量做為質心的初值)
        centroids[i,:] = dataSet[index,:] #把對應行的四個次元傳給質心的集合
    return centroids
 
# k均值聚類算法
def KMeans(dataSet,k):
    m = np.shape(dataSet)[0]  #行數150
    # 第一列存每個樣本屬于哪一簇(四個簇)
    # 第二列存每個樣本的到簇的中心點的誤差
    clusterAssment = np.mat(np.zeros((m,2)))# .mat()建立150*2的矩陣
    clusterChange = True
 
    # 1.初始化質心centroids
    centroids = randCent(dataSet,k)#4*4
    while clusterChange:
        # 樣本所屬簇不再更新時停止疊代
        clusterChange = False
 
        # 周遊所有的樣本(行數150)
        for i in range(m):
            minDist = 100000.0
            minIndex = -1
 
            # 周遊所有的質心
            #2.找出最近的質心
            for j in range(k):
                # 計算該樣本到4個質心的歐式距離,找到距離最近的那個質心minIndex
                distance = distEclud(centroids[j,:],dataSet[i,:])
                if distance < minDist:
                    minDist = distance
                    minIndex = j
            # 3.更新該行樣本所屬的簇
            if clusterAssment[i,0] != minIndex:
                clusterChange = True
                clusterAssment[i,:] = minIndex,minDist**2
        #4.更新質心
        for j in range(k):
            # np.nonzero(x)傳回值不為零的元素的下标,它的傳回值是一個長度為x.ndim(x的軸數)的元組
            # 元組的每個元素都是一個整數數組,其值為非零元素的下标在對應軸上的值。
            # 矩陣名.A 代表将 矩陣轉化為array數組類型
            
            # 這裡取矩陣clusterAssment所有行的第一列,轉為一個array數組,與j(簇類标簽值)比較,傳回true or false
            # 通過np.nonzero産生一個array,其中是對應簇類所有的點的下标值(x個)
            # 再用這些下标值求出dataSet資料集中的對應行,儲存為pointsInCluster(x*4)
            pointsInCluster = dataSet[np.nonzero(clusterAssment[:,0].A == j)[0]]  # 擷取對應簇類所有的點(x*4)
            centroids[j,:] = np.mean(pointsInCluster,axis=0)   # 求均值,産生新的質心
            # axis=0,那麼輸出是1行4列,求的是pointsInCluster每一列的平均值,即axis是幾,那就表明哪一次元被壓縮成1
 
    print("cluster complete")
    return centroids,clusterAssment

def draw(data,center,assment):
    length=len(center)
    fig=plt.figure
    data1=data[np.nonzero(assment[:,0].A == 0)[0]]
    data2=data[np.nonzero(assment[:,0].A == 1)[0]]
    data3=data[np.nonzero(assment[:,0].A == 2)[0]]
    # 選取前兩個次元繪制原始資料的散點圖
    plt.scatter(data1[:,0],data1[:,1],c="red",marker='o',label='label0')
    plt.scatter(data2[:,0],data2[:,1],c="green", marker='*', label='label1')
    plt.scatter(data3[:,0],data3[:,1],c="blue", marker='+', label='label2')
    # 繪制簇的質心點
    for i in range(length):
        plt.annotate('center',xy=(center[i,0],center[i,1]),xytext=\
        (center[i,0]+1,center[i,1]+1),arrowprops=dict(facecolor='yellow'))
        #  plt.annotate('center',xy=(center[i,0],center[i,1]),xytext=\
        # (center[i,0]+1,center[i,1]+1),arrowprops=dict(facecolor='red'))
    plt.show()
    
    # 選取後兩個次元繪制原始資料的散點圖
    plt.scatter(data1[:,2],data1[:,3],c="red",marker='o',label='label0')
    plt.scatter(data2[:,2],data2[:,3],c="green", marker='*', label='label1')
    plt.scatter(data3[:,2],data3[:,3],c="blue", marker='+', label='label2')
    # 繪制簇的質心點
    for i in range(length):
        plt.annotate('center',xy=(center[i,2],center[i,3]),xytext=\
        (center[i,2]+1,center[i,3]+1),arrowprops=dict(facecolor='yellow'))
    plt.show()
    
    

dataSet = X
k = 3
centroids,clusterAssment = KMeans(dataSet,k)
draw(dataSet,centroids,clusterAssment)
           

輸出結果如下圖3所示:

【計算機視覺(CV)】基于k-means實作鸢尾花聚類【計算機視覺(CV)】基于k-means實作鸢尾花聚類前言一、鸢尾花資料集描述二、資料集預處理三、模型訓練四、可視化展示六、模型預測總結

六、模型預測

總結

本系列文章内容為根據清華社出版的《自然語言處理實踐》所作的相關筆記和感悟,其中代碼均為基于百度飛槳開發,若有任何侵權和不妥之處,請私信于我,定積極配合處理,看到必回!!!

繼續閱讀