天天看點

python 實作感覺機(perceptron)

文章目錄

    • 一、基本原理
    • 二、實作思路
    • 三、源代碼
    • 四、代碼運作結果

一、基本原理

感覺機(perceptron)是一個二類分類的線性分類模型,其幾何意義是尋找一個超平面将點(特征空間)劃分為正類和負類。本文以二維平面為例,實作一個簡單的感覺機模型。

二、實作思路

在二維平面中,感覺機的訓練過程即尋找一條直線,這條直線可以将平面中線性可分的點分離開,代碼實作思路如下:

  1. 生成訓練資料:

    為了保證資料是線性可分的,在生成資料前确定兩個點(如:(2, 2)、(6, 6)),在這兩個點的周圍随機生成資料,分别給這兩個點周圍的資料加上不同的标簽 -1 和 +1。

  2. 訓練模型(獲得直線參數)

    我們不妨設要尋找的直線方程為:w[0] * x + w[1] * y + b,初始化參數 w = [0., 1.] 和 b = 0,即直線的初始方程為 y = 0。

    接着從訓練集中取出一個點,将這個點帶入到目前訓練的直線中,如果求出的值和該點的标簽乘積小于等于零,說明直線沒有将這個點正确分類,這時更新 w[0],w[1] 和 b 的值。

    更新規則為:w[0] += 學習率 * 标簽值 * 點的橫坐标,w[1] += 學習率 * 标簽值 * 點的縱坐标,b += 學習率 * 标簽值。

    周遊訓練集的所有點,如果點沒有正确分類就按上述更新規則更新 w 和 b 的值,直到所有點都被正确分類為止。根據得到的 x, y 的參數 w[0], w[1] 和 b 的值,計算出直線的斜率和截距。

  3. 将訓練資料(點)和直線畫出

    根據步驟 2 得到的斜率和截距畫出直線,根據訓練集的點和點的标簽畫出點。

三、源代碼

"""
@description: perceptron
@author: Zhao Chengcheng
"""

import numpy as np
import matplotlib.pyplot as plt


def get_data(num):
    """
    @description: 随機生成資料
    @param num: 資料條數
    @return data: 點的坐标
    @return label: 每個點的标簽,為:-1 或 +1
    """
    data = [] # 存放随機生成的坐标 Xn
    label = [] # 存放資料的标簽, -1 或者 +1
    x1 = np.random.normal(2, 0.8, int(num / 2))
    y1 = np.random.normal(2, 0.8, int(num / 2)) # 在點 (2, 2) 周圍生成點
    x2 = np.random.normal(6, 0.8, int(num / 2))
    y2 = np.random.normal(6, 0.8, int(num / 2)) # 在點 (6, 6) 周圍生成點,保證生成的點是可被劃分的
    for i in range(num):
        if i < num / 2:
            data.append([x1[i], y1[i]])
            label.append(-1)
        else:
            data.append([x2[int(i - num / 2)], y2[int(i - num / 2)]])
            label.append(1)
    return data, label


def perceptron(data, label, eta):
    """
    訓練感覺機
    @param data: 包含坐标的資料
    @param label: 資料的标簽 -1 或者 +1
    @param eta: 學習率
    @return slope: 斜率
    @return intercept: 截距
    """
    w = [0., 1.0] # 直線 x 和 y 的系數
    b = 0.
    separated = False # 标記是否已将點完全分離
    while not separated:
        separated = True
        for i in range(len(data)):
            if label[i] * (w[0] * data[i][0] + w[1] * data[i][1] + b) <= 0:
                separated = False # 沒有完全分離
                w[0] += eta * label[i] * data[i][0] # 更新 w 的值
                w[1] += eta * label[i] * data[i][1]
                b += eta * label[i] # 更新 b 的值
    slope = -w[0] / w[1]    # 斜率
    intercept = -b / w[1]   # 截距
    return slope, intercept


def plot(data, label, slope, intercept):
    """
    @description: 畫出點和超平面(直線)
    @param data: 點的坐标
    @param label: 點的标簽
    @param slope: 直線的斜率
    @param intercept: 直線的縱截距
    """
    plt.rcParams['font.sans-serif'] = ['SimHei'] # 設定字型
    plt.rcParams['axes.unicode_minus'] = False
    plt.xlabel('X')
    plt.ylabel('Y')
    area = np.pi * 2 ** 2 # 點的面積

    data_mat = np.array(data)
    X = data_mat[:, 0]
    Y = data_mat[:, 1]
    for i in range(len(label)):
        if label[i] > 0:
            plt.scatter(X[i].tolist(), Y[i].tolist(), s=area, color='red')  # 畫點
        else:
            plt.scatter(X[i].tolist(), Y[i].tolist(), s=area, color='green')
    # 根據斜率和截距畫出直線
    axes = plt.gca()
    x_vals = np.array(axes.get_xlim())
    y_vals = intercept + slope * x_vals
    plt.plot(x_vals, y_vals)
    plt.show()


data, label = get_data(100) # 生成資料和标簽
slope, intercept = perceptron(data, label, 1) # 訓練模型,得到直線的斜率和截距
plot(data, label, slope, intercept) # 畫出點和直線
           

四、代碼運作結果

python 實作感覺機(perceptron)
python 實作感覺機(perceptron)

源碼位址:感覺機原始形式實作