天天看點

DARTS+:DARTS 搜尋為何需要早停?

論文位址:

https://www.weiranhuang.com/publications/DARTS+.pdf

DARTS+ 在原始 DARTS 算法基礎上隻需簡單地加入一條早停機制,就可以在 CIFAR10、CIFAR100 和 ImageNet 上取得 2.32%、14.87% 和 23.7% 的錯誤率,超越一系列現有的 DARTS 改進算法,包括 SNAS[2]、P-DARTS[3]、XNAS[4]、PC-DARTS[5] 等。

在模型大小相當的情況下,DARTS+ 可以達到與谷歌提出的 EfficientNet[6] 相同的性能,但是搜尋時間卻遠遠小于 EfficientNet,再疊加上一些常用的 tricks,在 ImageNet 上可以達到 22.5% 的錯誤率!早停機制的引入,讓原本在搜尋時間上具有顯著優勢的基于「可微分」的架構搜尋方法,在性能上也開始超越基于「強化學習」或「演化算法」的架構搜尋方法,極大地增加了「可微分架構搜尋」的研究價值和應用範圍。

簡介

神經網絡架構搜尋(Neural Architecture Search,NAS)在自動機器學習(AutoML)中扮演着重要的角色,近來獲得越來越多的關注。用 NAS 搜尋得到的神經網絡架構已經在多種任務上超越了專家手工設計的網絡架構,包括物體分類、物體檢測、推薦系統等。

神經網絡架構搜尋的常見做法是首先設計一個架構搜尋空間,然後用某種搜尋政策,從中找出一個最優的網絡架構。早期的方案是基于強化學習(RL)或者演化算法(Evolutionary Algorithm)來搜尋一個有效的網絡架構,但是會耗費大量的計算資源(上千個 GPU days),不經濟也不環保。後來,一些 One-Shot 的方案相繼被提出,其中最具代表性的是 DARTS[1] 算法(Differentiable Architecture Search,可微分的神經網絡架構搜尋)。它把搜尋空間從離散的放松到連續的,進而能夠用梯度下降來同時搜尋架構和學習權重。具體來說,DARTS 使用了如下的兩層優化(Bi-Level Optimization)來搜尋:

DARTS+:DARTS 搜尋為何需要早停?

Bi-Level Optimization in DARTS

其中,alpha 是架構的參數,w 是 alpha 對應的模型權重。前者利用 validation data 來進行更新,後者利用 training data 來進行更新。具體細節可以參看 DARTS 的原文。DARTS 成功把搜尋時間從上千個 GPU days 減少到了幾個 GPU days。

DARTS 算法的問題

DARTS 算法有一個嚴重的問題,就是當搜尋輪數過大時,搜尋出的架構中會包含很多的 skip-connect,進而性能會變得很差。我們把這個現象叫做 Collapse of DARTS。

舉個例子,讓我們來考慮在 CIFAR100 上用 DARTS 做搜尋。從下圖可以看出,當 search epoch(橫軸)比較大的時候,skip-connect 的 alpha 值(綠線)将變得很大。

DARTS+:DARTS 搜尋為何需要早停?

Alpha Values in The Shallowest Edge

是以,在 DARTS 最後選出的網絡架構中,skip-connect 的數量也會随着 search epoch 變大而越來越多,如下圖中的綠線所示。

DARTS+:DARTS 搜尋為何需要早停?

在一個節點數固定的 cell 中,skip-connect 的數量越多,會導緻網絡變得越淺。相比于深度網絡,淺度網絡可學習的參數更少,具有的表達能力更弱。是以,在 DARTS 搜出的網絡架構中,skip-connect 的數量太多會導緻性能急劇變差。例如,在上圖中,當 skip-connect 的數量超過 2 個的時候,網絡的性能(藍線)開始降低。下圖直覺展示了随着 search epoch 變大,網絡結構由深變淺的過程。

DARTS+:DARTS 搜尋為何需要早停?

不同 search epoch 的情形下,在 CIFAR100 上用 DARTS 挑選出的網絡結構圖

DARTS 發生 Collapse 背後的原因是在兩層優化中,alpha 和 w 的更新過程存在先合作(cooperation)後競争(competition)的問題。粗略來說,在剛開始更新的時候,alpha 和 w 是一起被優化,進而 alpha 和 w 都是越變越好。漸漸地,兩者開始變成競争關系,由于 w 在競争中比 alpha 更有優勢(比如,w 的參數量大于 alpha 的參數量,One-Shot 模型在大多數 alpha 下都能收斂,等等),alpha 開始被抑制,是以網絡架構出現了先變好後變差的結果,也就是上上圖中藍線的情況。

具體來說,在搜尋過程的初始階段,One-Shot 模型欠拟合到資料集,是以在搜尋過程剛開始的時候,alpha 和 w(也就是 One-Shot 模型的參數)都會朝着變好的方向更新,這就是合作的階段。由于整個 One-Shot 模型中,前面的 cell 比後面的 cell 能接觸到更幹淨的資料,如果我們允許不同的 cell 可以擁有不同的網絡結構(打破 DARTS 中 cell 共享網絡結構的設定),那麼前面的 cell 會比後面的 cell 更快地學到特征。

一旦前面的 cell 已經學到了不錯的特征表達,而後面的 cell 學到的特征表達相對較差,那麼後面的 cell 接下來會傾向于選擇 skip-connect,來把前面 cell 已經學好的特征表達直接傳遞到後面。下圖是打破 DARTS 中 cell 共享網絡結構的設定下,搜出來的網絡結構圖:可以看到,前面的 cell 大部分都是卷積算子,而靠後的 cell 大部分都是 skip-connect。

DARTS+:DARTS 搜尋為何需要早停?

打破 cell 共享網絡結構的設定下,不同位置的 cell 搜出來的網絡結構圖

回到 DARTS 的設定,如果我們強制不同的 cell 共享同一個網絡結構,那麼 skip-connect 就會從後面的 cell 擴散到前面的 cell。當 skip-connect 開始顯著變多的時候,合作的階段就轉向了競争的階段:alpha 開始變壞,DARTS 開始 collapse。

值得一提的是,兩層優化中的合作和競争現象在其他應用中(比如 GAN,meta-learning 等)也有被觀察到。以 GAN 為例,一個學好的 discriminator 對訓練一個 generator 是至關重要的 [7],這是 generator 和 discriminator 之間的合作;當輸入資料(fake 或 real)落在低維流形上同時 discriminator 過參數化的時候,discriminator 很容易把生成的 fake data 從 real data 中區分開來,同時 generator 也會因為發生梯度消失導緻無法生成 real data[8],這是 generator 和 discriminator 之間的競争。

DARTS+:引入早停機制

為了解決 DARTS 會 collapse 的問題,防止 skip-connect 産生過多,我們提出一種非常簡單而且行之有效的早停機制,改進後的 DARTS 算法稱之為 DARTS+ 算法。本文中我們仍然遵循 DARTS 中 cell 共享網絡結構的設定,将探索如何打破 cell 網絡結構共享留為 future work。

早停準則:當一個 cell 中出現兩個及兩個以上的 skip-connect 的時候,搜尋過程停止。

DARTS+ 最大的優點就是操作起來非常簡單。相比于其他改進 DARTS 的算法,DARTS+ 隻需要一點點改動就可以顯著地提高性能,同時還能直接減少搜尋時間。

DARTS+:DARTS 搜尋為何需要早停?

上圖中的紅圈代表各個可學習算子(比如卷積)的 alpha 排序不再改變的時間點(具體細節請參看原文)。

由于 alpha 值最大的可學習算子對應最後的網絡會選擇的算子,當 alpha 排序穩定時,這個算子在最後選擇的網絡不會出現變化,這說明 DARTS 的搜尋過程已經充分。從上圖中藍線也能看出,當過了紅圈之後,架構的性能開始出現下降,進而出現 collapse 問題。是以,我們可以選擇在可學習算子 alpha 排序不再改變(圖中紅圈處)的時間點附近早停。當早停準則滿足時(左圖中紅色虛線),基本處于 DARTS 搜尋充分處,是以在早停準則處停止搜尋能夠有效防止 DARTS 發生 collapse。

通過上面的分析,我們可以給出一個稍複雜但更為直接的早停準則:

早停準則*:當各個可學習算子(比如卷積)的 alpha 排序足夠穩定(比如 10 個 epoch 保持不變)的時候,搜尋過程停止。

我們指出,第一個早停準則更便于操作,而當需要更精準的停止或者引入其他的搜尋空間的時候,我們可以用早停準則* 來代替。由于早停機制解決了 DARTS 搜尋中固有存在的問題,是以,它也可以被用在其它基于 DARTS 的算法中來幫助提高進一步性能。

值得一提的是,近來的一些基于 DARTS 改進的算法其實也隐式地使用了早停的想法。

P-DARTS[3] 使用了:1)搜 25 個 epoch 來代替原來的 50 個 epoch,2)在 skip-connects 之後加 dropout,3)手動把 skip-connects 的數目減到 2。

Auto-DeepLab[9] 使用了 20 個 epoch 來訓架構參數 alpha,同時發現更多的 epoch(60,80,100)對性能沒有好處。

PC-DARTS[5] 使用部分通道連接配接來降低搜尋時間,是以搜尋收斂需要引入更多的 epoch,進而仍然搜尋 50 個 epoch 就是一個隐式的早停機制。

實驗驗證

我們在 CIFAR10[10]、CIFAR100[10]、Tiny-ImageNet-200[11] 和 ImageNet[12] 上分類問題進行驗證。在實驗中,我們預設使用第一個早停準則。具體的實作細節,請參看原文。

實驗結果如下:

DARTS+ 在 CIFAR10、CIFAR100 和 ImageNet 上取得 2.32%、14.87% 和 23.7% 的錯誤率,超越一系列現有的 DARTS 改進算法,包括 SNAS[2]、P-DARTS[3]、XNAS[4]、PC-DARTS[5] 等。在模型大小相當的情況下,DARTS+ 可以達到與谷歌提出的 EfficientNet-B0[6] 相同的性能,但是搜尋時間卻遠遠小于 EfficientNet。如果再疊加 SE 子產品,mixup 等,在 ImageNet 上可以達到 22.5% 的錯誤率。

具體的性能名額如下所示:

DARTS+:DARTS 搜尋為何需要早停?

CIFAR10 和 CIFAR100 上的實驗結果

DARTS+:DARTS 搜尋為何需要早停?

Tiny-ImageNet-200 上的實驗結果

DARTS+:DARTS 搜尋為何需要早停?

ImageNet 上的實驗結果

結語

綜上所述,DARTS+ 簡單優雅地解決了 DARTS 算法中固有的 collapse 問題,通過引入操作起來十分簡單的早停機制,既縮短了搜尋時間,又極大地提高了性能。想要進一步提升 DARTS 的性能,一個可行的方向是考慮打破 DARTS 中「不同 cell 共享網絡架構」的設定。

本文為機器之心專欄,轉載請聯系本公衆号獲得授權。

繼續閱讀