SWA(隨機權重平均)
[Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
隨機權重平均:在優化的末期取k個優化軌迹上的checkpoints,平均他們的權重,得到最終的網絡權重。這樣會使得最終的權重位于flat曲面更中心的位置,緩解權重震蕩問題,獲得一個更加平滑的解,相比于傳統訓練有更好泛化性的解。
效果如下:
SWA和EMA
在〈EMA 指數滑動平均(Exponential Moving Average)〉一文中我們討論了指數滑動平均,可以發現SWA和EMA是有相似之處:
- 都是在訓練之外的操作,不影響訓練過程。
- 與集成學習類似,都是一種權值的平均:EMA是一種指數平均,會賦予近期權重更大的比重;SWA則是平均地賦予權重。
因此這裡參考了 torchcontrib 的 SWA 實作,並添加了 EMA 的實作,兩者的不同僅在于影子權值的更新方式。
class WeightAverage(Optimizer):
    """Wrap a base optimizer with shadow-weight averaging (SWA or EMA).

    Reference implementation:
    https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
    Paper: "Averaging Weights Leads to Wider Optima and Better Generalization"
    (https://arxiv.org/abs/1803.05407)

    Two averaging schemes are supported:
      * ``'swa'`` -- equal-weight running mean of the checkpoints.
      * ``'ema'`` -- exponential moving average; recent weights count more.

    Two usage modes:
      * automatic -- pass ``wa_start`` / ``wa_freq`` (and optionally
        ``wa_lr``); the average is refreshed inside :meth:`step`.
      * manual -- leave them ``None`` and call :meth:`update_swa` yourself.
    """

    def __init__(self, optimizer, wa_start=None, wa_freq=None, wa_lr=None, mode='swa'):
        """
        Args:
            optimizer (torch.optim.Optimizer): optimizer to use with WA.
            wa_start (int): step at which averaging starts (automatic mode).
            wa_freq (int): number of steps between average updates.
            wa_lr (float): learning rate applied from ``wa_start`` on
                (automatic mode only).
            mode (str or float): ``'swa'``, ``'ema'``, or a float in (0, 1)
                which selects EMA and is used as the decay factor.
        """
        if isinstance(mode, float):
            # A float selects EMA and doubles as the decay factor.
            self.mode = 'ema'
            self.beta = mode
        else:
            self.mode = mode
            if mode == 'ema':
                # BUGFIX: a plain mode='ema' used to leave ``beta`` unset,
                # which crashed later in update_swa_group; use a common
                # default decay instead.
                self.beta = 0.999
        self._auto_mode, (self.wa_start, self.wa_freq) = self._check_params(wa_start, wa_freq)
        self.wa_lr = wa_lr
        # Parameter validation.
        if self._auto_mode:
            if wa_start < 0:
                raise ValueError("Invalid wa_start: {}".format(wa_start))
            if wa_freq < 1:
                raise ValueError("Invalid wa_freq: {}".format(wa_freq))
        else:
            if self.wa_lr is not None:
                warnings.warn("Some of wa_start, wa_freq is None, ignoring wa_lr")
            # Manual mode: the automatic-schedule parameters are meaningless.
            self.wa_lr = None
            self.wa_start = None
            self.wa_freq = None
        if self.wa_lr is not None and self.wa_lr < 0:
            raise ValueError("Invalid WA learning rate: {}".format(wa_lr))
        self.optimizer = optimizer
        # Mirror the wrapped optimizer's attributes so this object can be
        # used wherever an Optimizer is expected.
        self.defaults = self.optimizer.defaults
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.opt_state = self.optimizer.state
        for group in self.param_groups:
            # EMA does not need the running count of averaged snapshots;
            # it is kept anyway so both modes share the same group layout.
            group['n_avg'] = 0
            group['step_counter'] = 0

    @staticmethod
    def _check_params(swa_start, swa_freq):
        """Validate schedule params, decide auto/manual mode, cast to int.

        Returns:
            (bool, list): ``(auto_mode, [swa_start, swa_freq])``.
        """
        params = [swa_start, swa_freq]
        params_none = [param is None for param in params]
        # Only some of them given -> fall back to manual mode with a warning.
        if not all(params_none) and any(params_none):
            warnings.warn("Some of swa_start, swa_freq is None, ignoring other")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_start, swa_freq to int")
        return not any(params_none), params

    def _reset_lr_to_swa(self):
        """Apply ``wa_lr`` to every group whose step counter passed ``wa_start``."""
        if self.wa_lr is None:
            return
        for param_group in self.param_groups:
            if param_group['step_counter'] >= self.wa_start:
                param_group['lr'] = self.wa_lr

    def update_swa_group(self, group):
        """Fold the group's current weights into its shadow average.

        SWA uses a running arithmetic mean, EMA an exponential one.
        """
        for p in group['params']:
            param_state = self.state[p]
            if 'wa_buffer' not in param_state:
                # Lazily created; note EMA therefore starts from zero
                # (no bias correction, matching the original behaviour).
                param_state['wa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['wa_buffer']
            if self.mode == 'swa':
                virtual_decay = 1 / float(group["n_avg"] + 1)
                # buf + (p - buf) / (n + 1) == (p + n * buf) / (n + 1)
                diff = (p.data - buf) * virtual_decay
                buf.add_(diff)
            else:
                buf.mul_(self.beta).add_((1 - self.beta) * p.data)
        group["n_avg"] += 1

    def update_swa(self):
        """Manual mode: update the shadow average of every parameter group."""
        for group in self.param_groups:
            self.update_swa_group(group)

    def swap_swa_sgd(self):
        """Swap model weights with their shadow averages (in place).

        Call at the end of training / before evaluation; call again to
        restore the raw weights.
        """
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'wa_buffer' not in param_state:
                    warnings.warn("WA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['wa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)

    def step(self, closure=None):
        """Run the wrapped optimizer step; in auto mode refresh the average.

        The average is updated every ``wa_freq`` steps once the per-group
        counter exceeds ``wa_start``.
        """
        self._reset_lr_to_swa()
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode:
                if steps > self.wa_start and steps % self.wa_freq == 0:
                    self.update_swa_group(group)
        return loss

    def state_dict(self):
        """Pack base-optimizer state, WA shadow state and param groups.

        Returns a dict with keys ``opt_state``, ``wa_state``, ``param_groups``.
        NOTE(review): tensor keys are replaced by ``id(k)``, which is not
        stable across processes — same caveat as the torchcontrib original.
        """
        opt_state_dict = self.optimizer.state_dict()
        wa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {"opt_state": opt_state, "wa_state": wa_state,
                "param_groups": param_groups}

    def load_state_dict(self, state_dict):
        """Load both the WA shadow state and the base optimizer's state."""
        wa_state_dict = {"state": state_dict["wa_state"],
                         "param_groups": state_dict["param_groups"]}
        opt_state_dict = {"state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(WeightAverage, self).load_state_dict(wa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state

    def add_param_group(self, param_group):
        """Add a parameter group to the wrapped optimizer's ``param_groups``."""
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)

    @staticmethod
    def bn_update(loader, model, device=None):
        """Recompute BatchNorm ``running_mean`` / ``running_var``.

        After swapping in the averaged weights, the BN statistics no longer
        match them; one full pass over ``loader`` in train mode rebuilds them.
        Relies on the module-level ``_check_bn`` / ``_reset_bn`` /
        ``_get_momenta`` / ``_set_momenta`` helpers.
        """
        if not _check_bn(model):
            return
        was_training = model.training
        model.train()
        momenta = {}
        model.apply(_reset_bn)
        model.apply(lambda module: _get_momenta(module, momenta))
        n = 0
        for input in loader:
            if isinstance(input, (list, tuple)):
                input = input[0]
            b = input.size(0)  # batch size
            # Cumulative-average momentum so every sample counts equally.
            momentum = b / float(n + b)
            for module in momenta.keys():
                module.momentum = momentum
            if device is not None:
                input = input.to(device)
            model(input)
            n += b
        model.apply(lambda module: _set_momenta(module, momenta))
        model.train(was_training)
# BatchNorm utils
def _check_bn_apply(module, flag):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
flag[0] = True
def _check_bn(model):
flag = [False]
model.apply(lambda module: _check_bn_apply(module, flag))
return flag[0]
def _reset_bn(module):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
module.running_mean = torch.zeros_like(module.running_mean)
module.running_var = torch.ones_like(module.running_var)
def _get_momenta(module, momenta):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
momenta[module] = module.momentum
def _set_momenta(module, momenta):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
module.momentum = momenta[module]