文章目錄
-
- 一、基本原理
- 二、實作思路
- 三、源代碼
- 四、代碼運作結果
一、基本原理
感覺機(perceptron)是一個二類分類的線性分類模型,其幾何意義是尋找一個超平面将點(特征空間)劃分為正類和負類。本文以二維平面為例,實作一個簡單的感覺機模型。
二、實作思路
在二維平面中,感覺機的訓練過程即尋找一條直線,這條直線可以将平面中線性可分的點分離開,代碼實作思路如下:
-
生成訓練資料:
為了保證資料是線性可分的,在生成資料前确定兩個點(如:(2, 2)、(6, 6)),在這兩個點的周圍随機生成資料,分别給這兩個點周圍的資料加上不同的标簽 -1 和 +1。
-
訓練模型(獲得直線參數)
我們不妨設要尋找的直線方程為:w[0] * x + w[1] * y + b,初始化參數 w = [0., 1.] 和 b = 0,即直線的初始方程為 y = 0。
接着從訓練集中取出一個點,将這個點帶入到目前訓練的直線中,如果求出的值和該點的标簽乘積小于等于零,說明直線沒有将這個點正确分類,這時更新 w[0],w[1] 和 b 的值。
更新規則為:w[0] += 學習率 * 标簽值 * 點的橫坐标,w[1] += 學習率 * 标簽值 * 點的縱坐标,b += 學習率 * 标簽值。
周遊訓練集的所有點,如果點沒有正确分類就按上述更新規則更新 w 和 b 的值,直到所有點都被正确分類為止。根據得到的 x, y 的參數 w[0], w[1] 和 b 的值,計算出直線的斜率和截距。
-
将訓練資料(點)和直線畫出
根據步驟 2 得到的斜率和截距畫出直線,根據訓練集的點和點的标簽畫出點。
三、源代碼
"""
@description: perceptron
@author: Zhao Chengcheng
"""
import numpy as np
import matplotlib.pyplot as plt
def get_data(num):
"""
@description: 随機生成資料
@param num: 資料條數
@return data: 點的坐标
@return label: 每個點的标簽,為:-1 或 +1
"""
data = [] # 存放随機生成的坐标 Xn
label = [] # 存放資料的标簽, -1 或者 +1
x1 = np.random.normal(2, 0.8, int(num / 2))
y1 = np.random.normal(2, 0.8, int(num / 2)) # 在點 (2, 2) 周圍生成點
x2 = np.random.normal(6, 0.8, int(num / 2))
y2 = np.random.normal(6, 0.8, int(num / 2)) # 在點 (6, 6) 周圍生成點,保證生成的點是可被劃分的
for i in range(num):
if i < num / 2:
data.append([x1[i], y1[i]])
label.append(-1)
else:
data.append([x2[int(i - num / 2)], y2[int(i - num / 2)]])
label.append(1)
return data, label
def perceptron(data, label, eta):
"""
訓練感覺機
@param data: 包含坐标的資料
@param label: 資料的标簽 -1 或者 +1
@param eta: 學習率
@return slope: 斜率
@return intercept: 截距
"""
w = [0., 1.0] # 直線 x 和 y 的系數
b = 0.
separated = False # 标記是否已将點完全分離
while not separated:
separated = True
for i in range(len(data)):
if label[i] * (w[0] * data[i][0] + w[1] * data[i][1] + b) <= 0:
separated = False # 沒有完全分離
w[0] += eta * label[i] * data[i][0] # 更新 w 的值
w[1] += eta * label[i] * data[i][1]
b += eta * label[i] # 更新 b 的值
slope = -w[0] / w[1] # 斜率
intercept = -b / w[1] # 截距
return slope, intercept
def plot(data, label, slope, intercept):
"""
@description: 畫出點和超平面(直線)
@param data: 點的坐标
@param label: 點的标簽
@param slope: 直線的斜率
@param intercept: 直線的縱截距
"""
plt.rcParams['font.sans-serif'] = ['SimHei'] # 設定字型
plt.rcParams['axes.unicode_minus'] = False
plt.xlabel('X')
plt.ylabel('Y')
area = np.pi * 2 ** 2 # 點的面積
data_mat = np.array(data)
X = data_mat[:, 0]
Y = data_mat[:, 1]
for i in range(len(label)):
if label[i] > 0:
plt.scatter(X[i].tolist(), Y[i].tolist(), s=area, color='red') # 畫點
else:
plt.scatter(X[i].tolist(), Y[i].tolist(), s=area, color='green')
# 根據斜率和截距畫出直線
axes = plt.gca()
x_vals = np.array(axes.get_xlim())
y_vals = intercept + slope * x_vals
plt.plot(x_vals, y_vals)
plt.show()
data, label = get_data(100) # 生成資料和标簽
slope, intercept = perceptron(data, label, 1) # 訓練模型,得到直線的斜率和截距
plot(data, label, slope, intercept) # 畫出點和直線
四、代碼運作結果
源碼位址:感覺機原始形式實作