天天看點

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

元學習系列文章

  1. optimization based meta-learning
    1. 《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》 論文翻譯筆記
    2. 元學習方向 optimization based meta learning 之 MAML論文詳細解讀
    3. MAML 源代碼解釋說明 (一)
    4. MAML 源代碼解釋說明 (二)
    5. 元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀:本篇部落格
    6. 元學習之《OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING》論文詳細解讀
  2. metric based meta-learning: 待更新…
  3. model based meta-learning: 待更新…

文章目錄

      • 引言
      • On First-Order Meta-Learning Algorithms
        • 僞算法
        • 數學過程
        • 訓練過程
        • 實驗
        • 核心代碼
      • OpenAI Demo
      • 幾點思考
      • 參考資料

引言

上一篇部落格對論文 MAML 做了詳細解讀,MAML 是元學習方向 optimization based 的開篇之作,還有一篇和 MAML 很像的論文 On First-Order Meta-Learning Algorithms,該論文是大名鼎鼎的 OpenAI 的傑作,OpenAI 對 MAML 做了簡化,但效果卻優于 MAML,具體做了什麼簡化操作,請往下看😀。

On First-Order Meta-Learning Algorithms

這篇論文的标題就很針對 MAML,MAML 中有一個重要的特點,就是在求梯度時,為了加速放棄了二階求導,使用一階微分近似進行代替,雖然效果上相差不大,但總感覺少了點什麼。這篇論文的标題上來就聲稱我們是一階的 metalearning 方法,而且剛好是在 MAML 發表的下一年(2018)發表在 ICML 會議的,從标題上也是賺慢了噱頭。

還有個有意思的事情,OpenAI 把論文中的算法稱之為 Reptile, 但是也沒有解釋為什麼叫這個,論文中也沒看出來和 Reptile 有什麼關聯,感興趣的讀者,可以去深究一下。

說了一堆廢話,下面開始進入正題。

僞算法

貼一張論文中的官方算法:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

先來解釋一下:

1 首先初始化一個網絡模型的所有參數 ϕ \phi ϕ

2 疊代 N 次,進行訓練,每次疊代執行:

  • 2.1 随機抽樣一個任務 T,用網絡模型進行訓練,對應的loss 是 L t L_t Lt​,訓練結束後的參數是 ϕ ~ \widetilde{\phi} ϕ

  • 2.2,在參數 ϕ \phi ϕ上使用 SGD 或 Adam 執行K次梯度下降更新,得到 ϕ ~ = U t k ( ϕ ) \widetilde{\phi}={U}^{k}_{t}(\phi) ϕ

    ​=Utk​(ϕ)

  • 2.3 用 ϕ ~ \widetilde{\phi} ϕ

    ​更新網絡模型模型參數, ϕ = ϕ + ϵ ( ϕ ~ − ϕ ) \phi=\phi+\epsilon(\widetilde{\phi}-\phi) ϕ=ϕ+ϵ(ϕ

    ​−ϕ)

3 完成上述N次疊代訓練,則結束整個過程

從上面的算法中可以看出,Reptile 是在每個單獨的任務執行K次訓練後,就開始真正更新網絡模型的參數(Meta),更新方式不是梯度下降,但是和梯度下降公式長得很像,是用上一次的參數 ϕ \phi ϕ和K次後的參數 ϕ ~ \widetilde{\phi} ϕ

​的差來更新,更新的步長是 ϵ \epsilon ϵ。在這個過程中,隻有一階求導的計算,就是在任務内部執行K次更新的過程中用到的随機梯度下降,這也是為什麼标題中叫

First-Order

的原因。

從這就可以看出和 MAML 算法的不同了:

  1. MAML:所有任務執行完,用每個任務測試集上的平均 loss 來更新 meta 參數。
  2. Reptile:每個任務執行K次訓練後,用最新的參數和 meta 參數的差來更新 meta 參數。

這裡說的meta參數,就是真正更新網絡模型參數的過程

數學過程

上面隻是簡單介紹了 Reptile 的算法思想,下面從數學過程上來了解下它的更新過程,先來設定幾個符号:

ϕ \phi ϕ代表網絡模型初始參數, ϵ , η \epsilon,\eta ϵ,η分别代表 meta 更新的學習率和 task 更新的學習率, N N N是meta訓練的 batch_size,即 meta 的一個bach有 N 個task,每個task内部執行K次訓練,N個任務都訓練完,再來更新meta參數。按照上面的算法過程,meta的一個batch訓練完之後,網絡模型的參數是:

ϕ = ϕ + ϵ 1 N ∑ i = 1 N ( ϕ i ~ − ϕ ) = ϕ + ϵ ( W − ϕ ) \begin{aligned} \phi &= \phi +\epsilon \frac{1}{N}\sum_{i=1}^{N}\left ( \tilde{\phi_i } -\phi\right )\\ &= \phi +\epsilon \left ( W-\phi \right )\\ \end{aligned} ϕ​=ϕ+ϵN1​i=1∑N​(ϕi​~​−ϕ)=ϕ+ϵ(W−ϕ)​

其中 W W W是每個任務最後參數的平均值,上述公式再進行展開就是這樣:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

假設N=2,K=3,即meta每次訓練的一個batch 有2個task,每個task内部進行3此疊代,則 meta每次更新模型參數的公式為:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

訓練過程

上面公式的最後一行,又變成了熟悉的梯度下降,隻不過梯度方向是每個任務内部更新的幾次梯度方向的和。meta 模型的參數更新過程,在幾何上就是這樣的:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

動圖看的更加清晰些,其中綠色代表第一個任務,三個綠色箭頭代表三次更新時的梯度方向,可以看到,Reptile的模型就是朝着每個任務的梯度和的方向上不斷地進行更新。

還記得 MAML 是怎樣更新的嗎?不記得的話,請翻看上一篇部落格。還是同樣的設定,MAML 的更新過程如下:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

即 MAML 是在每個任務最後一個梯度的方向上進行更新,而 Reptile 是在每個任務幾個梯度和的方向上進行更新。

實驗

實驗設定和 MAML 論文中的設定一樣,回歸任務以拟合正弦函數為例,分類任務以 MiniImagenet 資料和 omniglot 資料的圖檔分類為例,詳細設定就不再贅述了,直接看實驗結果:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

上半部分的圖是正弦函數的拟合結果,

(b)

是MAML的結果,

C

是Reptile的結果,橘黃色線是微調32次之後的樣子,綠色線是真實分布,可以看到 Reptile和MAML的結果相當,都能拟合到真實分布的樣子,硬要一較高下的話,那就是 Reptile稍好一些。

下半部分圖是在 MiniImagenet 分類資料上的結果,作者也對比了一階近似 MAML和二階MAML的結果,從圖中可以看出,Reptile的準确率至少要高出1個百分點。

在論文中作者還對比了一個有意思的實驗,Reptile 既然可以在 g 1 + g 2 + g 3 g_1+g_2+g_3 g1​+g2​+g3​ 的梯度方向上更新,那麼如果在其它梯度的組合方向上去更新,結果會怎樣呢?比如 g 1 + g 2 g_1+g_2 g1​+g2​ 等方向,作者也針對不同梯度的組合進行了實驗,實驗結果如下:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

橫軸是meta疊代次數,縱軸是準确率,不同顔色的曲線代表不同的梯度組合,可以明顯的看到最下面的藍色曲線準确率最低,藍色曲線代表在 g 1 g_1 g1​ 第一個梯度方向上去更新,其實就是模型預訓練的過程,以所有訓練任務的 loss 為準進行更新。其他顔色的曲線都代表用若幹次之後的 loss 來更新參數,最上面的那條曲線代表 Reptile,即用 g 1 + g 2 + g 3 + g 4 g_1+g_2+g_3+g_4 g1​+g2​+g3​+g4​ 的梯度方向進行更新,隻使用 g 4 g_4 g4​ 的那條曲線代表 MAML。

核心代碼

Reptile 的論文代碼也是開源的,而且代碼很簡介規範,不愧是 OpenAI 出品。建議感興趣的讀者去看下論文源碼,不僅能更好的了解論文思想,對工程能力的提升也很有幫助,包括代碼風格、子產品化、組織架構、邏輯實作等都有很多值得借鑒的地方。關于源代碼有疑問的話,可以私信聯系我。這裡隻貼一點核心的訓練更新代碼,對應上面的數學過程:

代碼檔案見 reptile.py

# 取出網絡模型的最新參數
        old_vars = self._model_state.export_variables()
        # 儲存一個 meta batch 裡,每個 task 更新 K 次後的參數
        new_vars = []
        for _ in range(meta_batch_size):
            # 抽樣出一個 task
            mini_dataset = _sample_mini_dataset(dataset, num_classes, num_shots)
            for batch in _mini_batches(mini_dataset, inner_batch_size, inner_iters, replacement):
                # task 裡面的訓練,更新 inner_iters 次,相當于公式中的K
                inputs, labels = zip(*batch) # inner_iters 個 batch,每個 iter 使用一個 batch ,裡面的一次訓練疊代
                if self._pre_step_op:
                    self.session.run(self._pre_step_op)
                self.session.run(minimize_op, feed_dict={input_ph: inputs, label_ph: labels})
            # 一個 task 内部訓練完的參數
            new_vars.append(self._model_state.export_variables())
            self._model_state.import_variables(old_vars)
        # 對 meta_batch 個 task 的最終參數進行平均,相當于公式中的 W
        new_vars = average_vars(new_vars)
        # 所有的 meta_batch 個任務都訓練完, 更新一次 meta 參數,并且把更新後的參數更新到計算圖中,下次訓練從最新參數開始
        # 更新方式:old + scale*(new - old)
        self._model_state.import_variables(interpolate_vars(old_vars, new_vars, meta_step_size))
           

OpenAI Demo

在 OpenAI 的官方部落格 Reptile: A Scalable Meta-Learning Algorithm中,也有介紹這篇論文。該部落格網頁中還有個有意思的 demo,大家可以試玩一下:

元學習之《On First-Order Meta-Learning Algorithms》論文詳細解讀

這個 demo 的意思是,openAI 已經用他們的 Reptile 算法訓練了一個用于少樣本場景的3分類網絡模型,并且嵌入到了網頁中,使用者可以通過 demo 中的互動制作一個新的三分類任務,并且這個任務隻有三個訓練樣本,也就是每個類下隻有一個樣本,學名叫3-Way 1-shot,讓他們的模型在這三個樣本上進行微調學習,然後在右邊畫一個新的三個類别下的測試樣本,Reptile 模型會自動給出它在三個類别下的機率。通過這個 demo 來證明他們的模型确實有奇效,在新任務的幾個樣本上微調一下,就可以在該任務的測試集上取得很好的準确率。

幾點思考

通過上面的 demo 可以得出一些結論:

  1. 畫圖框是固定尺寸,而且是黑白圖案,相當于輸入大小是固定的,是以可以用同一個模型進行訓練
  2. 框裡面可以任意畫一些圖案,比如畫數字

    1,2,3

    的圖案,那就變成了少樣本手寫數字識别任務;畫

    A,B,C

    的圖案,那就變成了手寫字母識别;畫三個貓、狗、兔子的圖案,那就變成了動物識别;這樣是不是說明了,通過 meta-learning 的方法預訓練網絡模型,可以在視覺場景中有廣泛應用 ?因為隻要輸入圖檔的尺寸是固定的,就可以一個模型應對所有任務。不知道這樣想是不是對的,如果是的話,那感覺看到了一個巨大的商機。
  3. Reptile 的方法能不能用到傳統的結構化資料上進行遷移 ?這就涉及到對 task 定義以及 task 間相似性的了解了,歡迎感興趣的讀者一起交流。

參考資料

  • https://arxiv.org/pdf/1803.02999.pdf
  • https://github.com/openai/supervised-reptile
  • https://www.bilibili.com/video/BV1Gb411n7dE?p=32

繼續閱讀