天天看點

神經網絡架構搜尋——可微分搜尋(SGAS)

動機

NAS技術都有一個通病:在搜尋過程中驗證精度較高,但是在實際測試精度卻沒有那麼高。傳統的基于梯度搜尋的DARTS技術,是根據block建構更大的超網,由于搜尋的過程中驗證不充分,最終eval和test精度會出現鴻溝。從下圖的Kendall系數來看,DARTS搜出的網絡精度排名和實際訓練完成的精度排名偏差還是比較大。

神經網絡架構搜尋——可微分搜尋(SGAS)

方法

整體思路

本文使用與DARTS相同的搜尋空間,SGAS搜尋過程簡單易懂,如下圖所示。類似DARTS搜尋過程為每條邊指定參數α,超網訓練時通過文中判定規則逐漸确定每條邊的具體操作,搜尋結束後即可得到最終模型。

神經網絡架構搜尋——可微分搜尋(SGAS)
神經網絡架構搜尋——可微分搜尋(SGAS)

為了保證在貪心搜尋的過程中能盡量保證搜尋的全局最優性,進而引入了三個名額和兩個評估準則。

三個名額

邊的重要性

非零操作參數對應的softmax值求和,作為邊的重要性衡量名額。

$$

S_{E I}^{(i, j)}=\sum_{o \in \mathcal{O}, o \neq z e r o} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}

alphas = []
for i in range(4):
    for n in range(2 + i):
        alphas.append(Variable(1e-3 * torch.randn(8)))
# alphas經過訓練後
mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach() # mat為14*8次元的二維清單,softmax歸一化。 
EI = torch.sum(mat[:, 1:], dim=-1) # EI為14個數的一維清單,去掉none後的7個ops對應alpha值相加           
選擇的準确性

計算操作分布的标準化熵,熵越小确定性越高;熵越高确定性越小。

\begin{array}{c}

p_{o}^{(i, j)}=\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{S_{E I}^{(i, j)} \sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}, o \in \mathcal{O}, o \neq z e r o \\

S_{S C}^{(i, j)}=1-\frac{-\sum_{o \in \mathcal{O}, o \neq z e r o} p_{o}^{(i, j)} \log \left(p_{o}^{(i, j)}\right)}{\log (|\mathcal{O}|-1)}

\end{array}

import torch.distributions.categorical as cate
probs = mat[:, 1:] / EI[:, None]
entropy = cate.Categorical(probs=probs).entropy() / math.log(probs.size()[1])
SC = 1-entropy           
選擇的穩定性

将曆史資訊納入操作分布評估,使用直方圖交叉核計算平均選擇穩定性。直方圖交叉核的原理詳見(

https://blog.csdn.net/hong__fang/article/details/50550656

)。

S_{S S}^{(i, j)}=\frac{1}{K} \sum_{t=T-K}^{T-1} \sum_{o_{t} \in \mathcal{O}, o_{t} \neq z e r o} \min \left(p_{o_{t}}^{(i, j)}, p_{o_{T}}^{(i, j)}\right)

def histogram_intersection(a, b):
  c = np.minimum(a.cpu().numpy(),b.cpu().numpy())
  c = torch.from_numpy(c).cuda()
  sums = c.sum(dim=1)
  return sums

def histogram_average(history, probs):
  histogram_inter = torch.zeros(probs.shape[0], dtype=torch.float).cuda()
  if not history:
    return histogram_inter
  for hist in history:
    histogram_inter += utils.histogram_intersection(hist, probs)
  histogram_inter /= len(history)
  return histogram_inter

probs_history = []

probs_history.append(probs)
if (len(probs_history) > args.history_size):
  probs_history.pop(0)
  
histogram_inter = histogram_average(probs_history, probs)

SS = histogram_inter           

兩種評估準則

評估準則1:

選擇具有高邊緣重要性和高選擇确定性的操作

S_{1}^{(i, j)}=\text { normalize }\left(S_{E I}^{(i, j)}\right) * \text { normalize }\left(S_{S C}^{(i, j)}\right)

def normalize(v):
  min_v = torch.min(v)
  range_v = torch.max(v) - min_v
  if range_v > 0:
    normalized_v = (v - min_v) / range_v
  else:
    normalized_v = torch.zeros(v.size()).cuda()

  return normalized_v

score = utils.normalize(EI) * utils.normalize(SC)           
評估準則2:

在評估準則1的基礎上,加入考慮選擇穩定性

S_{2}^{(i, j)}=S_{1}^{(i, j)} * \text { normalize }\left(S_{S S}^{(i, j)}\right)

score = utils.normalize(EI) * utils.normalize(SC) * utils.normalize(SS)           

實驗結果

CIFAR-10(CNN)

神經網絡架構搜尋——可微分搜尋(SGAS)

ImageNet(CNN)

神經網絡架構搜尋——可微分搜尋(SGAS)

ModelNet40(GCN)

神經網絡架構搜尋——可微分搜尋(SGAS)

PPI(GCN)

神經網絡架構搜尋——可微分搜尋(SGAS)

參考

[1] Li, Guohao et al. ,SGAS: Sequential Greedy Architecture Search

[2]

https://zhuanlan.zhihu.com/p/134294068

[3] 直方圖交叉核

![更多内容關注微信公衆号【AI異構】]

繼續閱讀