天天看點

mse優化方法(四)——Adam

import numpy as np
# 導入動畫包
import matplotlib.animation as animation
# Training samples, one row per sample: column 0 is used as the input
# feature and column 1 as the label (see Feature / label below).
data = np.array(
    [
        [80, 200],
        [95, 230],
        [104, 245],
        [112, 247],
        [125, 259],
        [135, 262],
    ]
)

# Two lists recording how m and b change from step to step.
mhistroy = list()
bhistroy = list()
# Records how the loss value (named "mse" in this script) changes.
msehistory = list()

# m and b are held together as a (2, 1) weight matrix, both starting at 1.
Weight = np.ones((2, 1))
# Column of ones placed next to the feature column so the second weight
# acts as the intercept b.
ones = np.ones((data.shape[0], 1))
Feature = np.hstack((data[:, :1], ones))
label = data[:, 1:]

learningrate = 0.1

# Running average of the gradient w.r.t. m and b (the "inertia" term).
m = np.zeros((2, 1))
# Running average of the squared gradient w.r.t. m and b (the "speed" term).
v = np.zeros((2, 1))

def grandentdecent3():
    """Perform one Adam update of the global ``Weight`` and record the step.

    Reads the module-level ``Feature``, ``label`` and ``learningrate``;
    rebinds the module-level ``Weight``, ``m`` and ``v``; appends one entry
    to each of ``msehistory``, ``mhistroy`` and ``bhistroy``.

    Takes no arguments and returns ``None``.
    """
    global Weight, m, v, learningrate
    # Residual of the current linear prediction; reused for loss and gradient.
    error = np.dot(Feature, Weight) - label
    # The recorded "mse" is the sum of squared residuals (it is not divided
    # by the number of samples).
    mse = np.sum(np.square(error))
    msehistory.append(mse)
    # 1-based step number. msehistory gains exactly one entry per call, so
    # its length counts the updates made so far.
    # NOTE(review): assumes msehistory is only appended to here and is reset
    # together with m and v.
    step = len(msehistory)
    # Gradient direction of the loss w.r.t. Weight (constant factor omitted).
    slop = np.dot(Feature.T, error)
    # Core Adam logic.
    beta_1 = 0.9
    beta_2 = 0.999

    # Exponential moving averages of the gradient and the squared gradient.
    m = beta_1 * m + (1 - beta_1) * slop
    v = beta_2 * v + (1 - beta_2) * (slop ** 2)
    # Bias correction: m and v start at zero, so after `step` updates they
    # are scaled by (1 - beta**step). Dividing by the constant (1 - beta)
    # is only correct on the very first step.
    m_p = m / (1 - beta_1 ** step)
    v_p = v / (1 - beta_2 ** step)

    # The small constant under the square root guards against division by 0.
    Weight = Weight - learningrate * m_p / np.sqrt(v_p + 0.000000001)

    mhistroy.append(Weight[0][0])
    bhistroy.append(Weight[1][0])
           
# Run 50000 update steps. The update function defined in this file is
# grandentdecent3; the previous call to grandentdecent7 referred to a name
# that is not defined here and raised NameError.
for i in range(50000):
    grandentdecent3()
           
## 以動畫的方式展示m和b收斂的過程

%matplotlib notebook
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6,6),dpi=60)
plt.xlim(0,5)
plt.ylim(0,130)

axis_name, =  plt.plot(mhistroy[0:100],bhistroy[0:100],c='r')

plt.annotate("goal",xy=(1.0859,122.68), xytext=(+10, +15),
             textcoords='offset points', fontsize=12,
             arrowprops=dict(arrowstyle="->"))

def update(num):
    """Show the first ``num * 100`` recorded (m, b) points on the line.

    ``num`` is the frame value handed in by the animation; each frame
    therefore reveals another 100 recorded steps of ``mhistroy`` and
    ``bhistroy``.
    """
    stop = num * 100
    axis_name.set_data(mhistroy[:stop], bhistroy[:stop])

# Keep a reference to the animation object: matplotlib documents that an
# Animation which is not stored in a variable can be garbage-collected,
# which stops the animation.
ani = animation.FuncAnimation(fig, update, np.arange(0, 501), interval=20, repeat=False)
           
mse優化方法(四)——Adam

繼續閱讀