天天看點

Optimizing Federated Learning on Non-IID Data with Reinforcement Learning 筆記

  • 閱讀論文 Optimizing Federated Learning on Non-IID Data with Reinforcement Learning 的筆記
  • 如有侵權,請聯系作者,将會撤銷釋出。

主要講什麼

  • 提出FAVOR,一個經驗驅動控制的架構。
  • 智能的選擇用戶端裝置來參與聯邦學習中每一輪訓練,以抵消資料非獨立同分布帶啊來的偏差,并提升收斂的速度。
  • 使用了deep Q-learning 來學習如何選擇每輪參與訓練的用戶端以最大化一個 鼓勵提升正确率并處罰使用更多通信次數的 獎勵。

Intro

  • 一般聯邦學習都是直接随機選取一部分裝置參與每輪的訓練,以避免由于不穩定的網絡狀況和straggler裝置造成的長尾(long-tailed)等待時間
  • FedAvg可能會嚴重的降低模型的準确性和收斂所需的通信次數
    • 而且由于資料非獨立同分布,聚合這些不同的模型可能會減慢收斂,并且會降低模型準确性
    • 一個裝置中的訓練資料的分布和訓練得到的模型參數之間有内含的聯系

這篇文章提出的目标

FAVOR的目标

  • 通過學習積極地在每輪選擇最好的,可以抵消非獨立同分布會帶來的偏差的裝置集,以加速并穩定聯邦學習過程。

選擇裝置

  • 用本地模型參數和共享的全局模型作為狀态,進而公平地?選擇可能對全局模型有所提升的裝置
  • 使用基于DQN的強化學習來提高效率和魯棒性。(在FL的裝置選擇環節中使用基于DQN的強化學習)

壓縮模型參數

  • 提出了一個可以壓縮模型參數以對狀态空間降維
  • apply principle component analysis(PCA) to model weights and use the compressed model weights to represent states instead.
  • 隻根據在第一輪訓練(step 2中得到的)的本地模型的參數來計算PCA
  • # TODO: 看不懂源碼,看不懂過程QAQ

非獨立同分布的挑戰

  • 論文中用實驗來展現:
    • 如果随機選取裝置,那麼非獨立同分布的資料可能會減慢聯邦學習的收斂速度。
    • 用cluster 算法可以有助于平衡資料分布并加快收斂。

實驗過程

  1. 100個裝置下載下傳最初的Global weights(随機生成的)然後根據本地資料執行一個epoch的SGD,獲得\(w_1^{(k)}\)
  2. 對\(w_1^{(k)}\)執行K-Center算法,對100個裝置進行聚類,分成了10個組。
  3. 在每個組裡面随機選擇一個裝置進行聯邦學習。
  • 結果:
    Optimizing Federated Learning on Non-IID Data with Reinforcement Learning 筆記
  • 這個實驗說明了:通過仔細選擇每輪參與訓練的裝置可以提高聯邦學習的性能。

用DRL來選擇用戶端

Agent 基于 Deep Q-Network

  • 用DQN來選擇k個最合适的裝置來參與訓練
  • 通過一個網絡來學習得到\(Q^*(s_t,a)\),選擇\(Q^*\)最大的k個裝置來訓練。
  • 因為裝置中資料非獨立同分布的原因,直接随機選擇裝置來訓練效果會不好,是以用這個DQN可以根據每個裝置中的模型參數來訓練,得到一個選擇裝置的政策。
  • \(s_t=(w_t,w_t^{(1)},...,w_t^{(N)})\)
  • \(a\): action space為{1,2,...,N}, a=1指選擇裝置i去參與FL訓練
  • DQN agent 被訓練為要最大化cumulative discounted reward (即R) 的期望。:
    • reward: \(r_t=\Xi ^{(w_t-\Omega)}-1\)
      • \(w_t\): 在第 t 輪結束後,對held-out validation set(保留驗證集)上的測試得出的準确度
      • \(\Omega\): 目标準确度
      • \(\Xi ^{(w_t-\Omega)}\): 激勵agent去選擇能取得更高準确度\(w_t\)的裝置
        • 由于通常随着在機器學習進行,模型準确度的增長速度會變慢,也就是随着t增加,\(|w_t-w_{t+1}|\)會減小。
        • 是以用這樣的指數項來放大FL過程靠後階段中微小的準确度的增長。
        • \(\Xi\): 一個正常數,論文中的實驗設定為了64
      • -1:激勵 agent 用更少的訓練輪數 (?)
    • \(R=\sum_{t=1}^{T}\gamma ^{t-1}r_t\)
    • 當\(w_t=\Omega, r_t == 0\) 時,聯邦學習結束

FAVOR過程

  1. N個可行的裝置向FL server報到
    1. 每個裝置都從server上下載下傳最初的随機獲得的模型參數\(w_{init}\)
    2. 用 local SGD 訓練一個epoch,然後将訓練得到的模型參數\(\{w_1^{(k)},k \in [N]\}\)傳給FL server
    1. 接收到上傳的local weights後,對應在server上存的local weights更新
    2. DQN agent 計算所有裝置的\(Q(s_t,a;\theta)\)
    1. DQN agent 根據\(Q(s_t,a;\theta)\)的大小,選擇k個最大Q值對應的k個裝置。
    2. 被選中的k個裝置下載下傳最新的global model weights \(w_t\), 并執行一個epoch的local SGD以獲得\(\{w_{t+1}^{(k)}|k \in [K]\}\)
  2. \(\{w_{t+1}^{(k)}|k \in [K]\}\)被傳到server,以使用FEDAVG計算\(w_{t+1}\)。重複3-5步直到結束(如達到目标準确率,或者 訓練了一定數量的rounds)。
  • 論文作者GitHub上還沒有給出這部分的代碼。

用PCA降維

  • 對模型參數使用PCA,然後用壓縮後的模型參數來表示states。
  • 看不懂這部分代碼

用Double DQN 訓練Agent

  • 使用DDQN來學習函數\(Q^*(s_t,a)\)
  • 原來的Q-Learning可能會不穩定
  • 而DDQN加入了另一個value function \(Q(s,a;\theta_t\')\),這樣可以使action-value函數的估計更加穩定