- 論文位址: https://arxiv.org/pdf/1905.13611.pdf
KDD 2019 | 不用反向傳播就能訓練DL模型,ADMM效果可超梯度下降 - 代碼位址: https://github.com/xianggebenben/dlADMM
論文:ADMM for Efficient Deep Learning with Global Convergence
本文提出了一種基于交替方向乘子法的深度學習優化算法 dlADMM。該方法可以避免随機梯度下降算法的梯度消失和病态條件等問題,彌補了此前工作的不足。此外,該研究提出了先後向再前向的疊代次序加快了算法的收斂速度,并且對于大部分子問題采用二次近似的方式進行求解,避免了矩陣求逆的耗時操作。在基準資料集的實驗結果表明,dlADMM 擊敗了大部分現有的優化算法,進一步證明了它的有效性和高效。
背景
深度學習已經在機器學習的各個領域受到廣泛的應用,因為 深度學習模型可以表征非線性特征的多層嵌套組合,是以相比傳統的機器學習模型,它的表達性更豐富。由于深度學習通常用在大資料的應用場景中,是以需要一種優化算法可以在有限的時間内得到一個可用的解。随機梯度下降算法 (SGD) 和它的許多變體 如 ADAM 是深度學習領域廣泛使用的優化算法,但是它存在着如梯度消失 (gradient vanishing) 和病态條件 (poor conditioning) 等問題;另一方面, 作為近年非常熱門的優化架構,交替方向乘子算法 (ADMM) 可以解決 SGD 存在的問題: ADMM 的基本原理是把一個複雜的複合目标函數分解成若幹個簡單的子函數求解,這樣不需要用鍊式法則求複合函數的導數,進而避免了梯度消失的問題,另外 ADMM 對輸入不敏感,是以不存在病态條件的問題 [1]。除此之外 ADMM 還有諸多優點:它 可以解決非光滑函數的優化問題;它在很多大規模的深度學習應用中展現了強大的可擴充性 (scalability) 等等。
經典反向傳播算法 (BP)
一個典型的神經網絡問題如下所示:
其中 W_l 和 b_l 是第 l 層的權重和截距, f_l (∙) 和 R(∙) 分别是激活函數和損失函數, L 是層數。經典的反向傳播算法分前向傳送梯度和後向更新參數兩部分:每層的梯度按照鍊式法則向前傳輸,然後根據損失函數反向參數更新如下:
盡管反向傳播非常實用,然而它存在一些問題,因為前面層級的梯度必須等後面層級的梯度算完,BP 的速度比較慢;因為 BP 需要儲存前向傳播的激活值,是以顯存占用比較高;最常見的問題就是對于深度神經網絡存在梯度消失, 這是因為根據鍊式法則
如果
, 那麼随着層數的增加,梯度信号發生衰減直至消失。連提出 BP 的 Geoffrey Hinton 都對它充滿了質疑。主要是因為反向傳播機制實在是不像大腦。Hinton 在 2017 年的時候提出膠囊理論嘗試替代标準神經網絡。但是隻要反向傳播的機制不改變,梯度消失的問題就不會解決。是以目前有許多研究者嘗試采用各種機制訓練神經網絡,早在 2016 年,Gavin Taylor[1] 等人就提出了 ADMM 的替代想法,ADMM 的原理如圖 1 所示,一個神經網絡按照不同的層被分解成若幹個子問題,每個子問題可以并行求解,這樣不需要求複合目标函數 R(∙) 的導數,解決了梯度下降的問題。按照 ADMM 的思路,神經網絡問題被等價轉換為如下問題 1:
問題 1:
其中 a_l 是輔助變量。
圖 1. 基于 ADMM 的深度神經網絡子問題分解示意圖
挑戰
盡管 ADMM 具備很多優點,但是把它應用在深度學習的問題中的效果較目前最優算法入 SGD 和 ADAM 還有很大差距,很多技術和理論問題仍亟待解決:1) 收斂慢。即使最簡單的目标函數,通常 ADMM 需要很多次疊代才能達到最優解。2) 對于特征次元的三次時間複雜度。在 Taylor 等人的實驗中,他們使用了超過 9000 核 CPU 來讓 ADMM 訓練了僅僅 300 個神經元 [1]。其中 ADMM 最耗時的地方在于求解逆矩陣, 它的時間複雜度大概在 O(n^3 ), 其中 n 是神經元的個數。3) 缺乏收斂保證。盡管很多實驗證明了 ADMM 在深度學習中是收斂的,然而它的理論收斂行為依然未知。主要原因是因為神經網絡是線性和非線性映射群組合體,因而是高度非凸優化問題。基于這些問題,最新的一期 KDD2019 論文提出了 ADMM 的改進版本 dlADMM,第一次使基于 ADMM 的算法在多個标準資料集上達到目前最佳效果,并且在收斂性理論證明得到重要突破。
dlADMM 相比 ADMM 的優勢:
- 加快收斂。文章提出了一種新的疊代方式加強了訓練參數的資訊交換,進而加快了 dlADMM 的收斂過程。
- 加快運作速度。作者通過二次近似的技術避免了求解逆矩陣,把時間複雜度從 O(n^3 ) 降低到 O(n^2 ),即與梯度下降相同的複雜度。進而大幅提高 ADMM 的運作速度。
- 具備收斂保證。本文第一次證明了 dlADMM 可以全局收斂到問題的一個駐點(該點導數為 0)。
下面具體讨論提出算法的這些優勢:
1). 加快收斂
直接用 ADMM 解問題 1 并不能保證收斂,是以作者把問題 1 放松成如下的問題 2。在問題 2 中,當ν→+∞ 時,問題 2 無線逼近 問題 1。
問題 2:
在問題 2 中,ν>0 是一個參數, z_l 是一個輔助變量。在解問題 2 的過程中,通過增大ν可以使其理論上無限逼近問題 1。
問題 2 的增廣拉格朗日 [2] 形式如下:
其中
是 dlADMM 中的一個超參數。
為了加快收斂過程,作者提出了一種新的疊代方式:先後向更新再前向更新,如圖 2 所示。具體來講,參數從最後一層開始更新,然後向前更新直到第一層,接着參數從第一層開始向後更新直到最後最後一層。這樣更新的好處在于最後一層的參數資訊可以層層傳遞到第一層,而第一層的參數資訊可以層層傳遞到最後一層,加強參數資訊交換,進而幫助參數更快地收斂。
圖 2. dlADMM 原理圖
2). 加快運作速度
對于求解 dlADMM 産生的子問題,大部分都需要耗時的矩陣求逆操作。為此,作者使用了二次近似的技術,如圖 3 所示。在每一次疊代的時候對目标函數做二次近似函數展開,由于變量的二次項是一個常數,是以不需要求解逆矩陣,進而提高了算法的運作效率。
圖 3. 二次近似
3.) 收斂保證
作者證明了無論參數 (W,b,z,a) 如何初始化,當ρ 足夠大的時候,dlADMM 全局收斂于問題 2 的一個駐點上。
具體來說,是基于如下兩條假設:
a. 求解 z 的子問題存在顯式解。b. F 是強制的 (coercive),R(∙) 是萊布尼茨可導 (Lipschitz differentiable)。
對于假設 a, 常用的激活函數如 Relu 和 Leaky Relu 滿足條件;對于假設 b,常用的交叉熵和最小二乘損失函數都滿足條件。
在此基礎之上, 本文證明了三條收斂性質:
- 是有界的,L_ρ是有下界的。
KDD 2019 | 不用反向傳播就能訓練DL模型,ADMM效果可超梯度下降 - L_ρ是單調下降的。
- L_ρ的次梯度趨向于 0。
同時文章也證明了 dlADMM 的收斂率是 o(1/k).
實驗結果
該論文在兩個基準資料集 MNIST 和 Fashion MNIST 上進行了實驗。
1. 收斂驗證
作者畫出了當ρ=1 和 ρ=〖10〗^(-6) 的收斂曲線,驗證了當 ρ 足夠大的時候,dlADMM 是收斂的(Figure 2),反之,dlADMM 不停地振蕩(Figure 3)。
2. 效果比較
作者把 dlADMM 和目前公認的算法進行了比較。比較的方法包括:a. 随機梯度下降 (SGD). b. 自适應性梯度算法 (Adagrad). c. 自适應性學習率算法 (Adadelta).d. 自适應動量估計 (Adam). e. 交替方向乘子算法 (ADMM).Figure 4 和 Figure 5 展示了所有算法在 MNIST 和 Fashion MNIST 的訓練集和測試集的正确率,可以看到開始的時候 dlADMM 上升最快,并且在二十次疊代之内迅速達到非常高的精度并且擊敗所有算法。在 80 次疊代之後雖被 ADAM 反超,但是仍然優于其他算法。文中提出的改進版 dlADMM 顯著提高了 ADMM 的表現,使之充分發揮 ADMM 在疊代初期進展迅速的優點,同時在後期保證較高收斂率。
3. 時間分析
文章作者分析了 dlADMM 的運作時間和資料集數量,神經元個數以及和 ρ的選取上面的關系。如 table 2 和 table 3 所示,當資料集數量, 神經元個數和ρ的值越大的時候, dlADMM 的運作時間越長。具體來說, 運作時間和神經元個數和資料集數量成線性關系。這個結果展現了 dlADMM 強大的可擴充性 ( scalability)。
本文提出的 dlADMM 代碼已經公開, 歡迎使用。連結如下:
。歡迎郵件聯系 [email protected] 或者 [email protected]。
文獻:
[1]. Taylor, Gavin, et al. "Training neural networks without gradients: A scalable admm approach". International conference on machine learning. 2016.[2]. Boyd, Stephen, et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine learning 3.1 (2011): 1-122.