天天看點

不依賴Python第三方庫實作梯度下降

認識

梯度的本意是一個向量(矢量),表示某一函數在該點處的方向導數沿着該方向取得最大值,即函數在該點處沿着該方向(此梯度的方向)變化最快,變化率最大(為該梯度的模), 我感覺, 其實就是偏導數向量方向呗, 沿着這個向量方向可以找到局部的極值.

from random import random

def gradient_down(func, part_df_func, var_num, rate=0.1, max_iter=10000, tolerance=1e-10):
    """
    不依賴第三庫實作梯度下降
    :param func: 損失(誤差)函數
    :param part_df_func: 損失函數的偏導數向量
    :param var_num: 變量個數
    :param rate: 學習率(參數的每次變化的幅度)
    :param max_iter: 最大計算次數
    :param tolerance: 誤差的精度
    :return: theta, y_current:  權重參數值清單, 損失函數最小值
    """

    theta = [random() for _ in range(var_num)]  # 随機給定參數的初始值
    y_current = func(*theta)  # 參數解包

    for i in range(max_iter):
        # 計算目前參數的梯度(偏導數導數向量值)
        gradient = [f(*theta) for f in part_df_func]
        # 根據梯度更新參數 theta
        for j in range(var_num):

            theta[j] -= gradient[j] * rate  # [0.3, 0.6, 0.7] ==> [0.3-0.3*lr, 0.6-0.6*lr, 0.7-0.7*lr]

            y_current,  y_predict = func(*theta), y_current
            print(f"正在進行第{i}次疊代, 誤差精度為{abs(y_predict - y_current)}")

            if abs(y_predict - y_current) < tolerance:   # 判斷是否收斂, (誤內插補點的精度)

                print(); print(f"ok, 在第{i}次疊代, 收斂到可以了哦!")

                return theta, y_current


def f(x, y):
    """原函數"""
    return (x + y - 3) ** 2 + (x + 2 * y - 5) ** 2 + 2


def df_dx(x, y):
    """對x求偏導數"""
    return 2 * (x + y - 3) + 2 * (x + 2 * y - 5)


def df_dy(x, y):
    """對y求偏導數, 注意求導的鍊式法則哦"""
    return 2 * (x + y - 3) + 2 * (x + 2 * y - 5) * 2


def main():
    """主函數"""
    print("用梯度下降的方式求解函數的最小值哦:")
    theta, f_theta = gradient_down(f, [df_dx, df_dy], var_num=2)

    theta, f_theta = [round(i, 3) for i in theta], round(f_theta, 3)  # 保留3位小數

    print("該函數最優解是: 當theta取:{}時,f(theta)取到最小值:{}".format(theta, f_theta))


if __name__ == '__main__':
    main()      
...
...
正在進行第248次疊代, 誤差精度為1.6640999689343516e-10
正在進行第249次疊代, 誤差精度為1.5684031851037616e-10
正在進行第250次疊代, 誤差精度為1.478208666583214e-10
正在進行第251次疊代, 誤差精度為1.3931966691416164e-10
正在進行第252次疊代, 誤差精度為1.3130829756846651e-10
正在進行第253次疊代, 誤差精度為1.2375700464417605e-10
正在進行第254次疊代, 誤差精度為1.166395868779091e-10
正在進行第255次疊代, 誤差精度為1.0993206345233375e-10
正在進行第256次疊代, 誤差精度為1.0361000946090826e-10
正在進行第257次疊代, 誤差精度為9.765166453234997e-11

ok, 在第257次疊代, 收斂到可以了哦!
該函數最優解是: 當theta取:[1.0, 2.0]時,f(theta)取到最小值:2.0
[Finished in 0.0s]