
SWA (Stochastic Weight Averaging)

Paper: [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)

Stochastic Weight Averaging: towards the end of optimization, take k checkpoints along the optimization trajectory and average their weights to obtain the final network weights. This places the final weights closer to the center of a flat region of the loss surface, dampens the oscillation of the weights, and yields a smoother solution that generalizes better than the one found by conventional training.

(Figures omitted: SWA illustration and results.)
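
As a minimal sketch of the idea (separate from the wrapper implemented below), averaging k checkpoints is just an equal-weight running mean over their state_dicts; `average_checkpoints` and its argument are hypothetical names used only for illustration:

import torch

def average_checkpoints(checkpoints):
    """Equal-weight running mean over a list of state_dicts with identical keys."""
    with torch.no_grad():
        avg = {k: v.clone().float() for k, v in checkpoints[0].items()}
        for n, ckpt in enumerate(checkpoints[1:], start=1):
            for k, v in ckpt.items():
                # avg <- avg + (v - avg) / (n + 1)  ==  (n * avg + v) / (n + 1)
                avg[k] += (v.float() - avg[k]) / (n + 1)
    return avg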

SWA and EMA

In the earlier post on EMA (Exponential Moving Average) we discussed the exponential moving average; comparing it with SWA, the two have clear similarities:

  • Both are operations performed outside the training loop and do not affect the training process itself.
  • Like ensemble learning, both are a form of weight averaging: EMA is an exponential average that gives more weight to recent iterates, whereas SWA weights all iterates equally (see the short sketch after this list).
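
The difference boils down to two one-line update rules for a shadow weight; the sketch below uses illustrative names (buf, w, n, beta) that mirror the implementation further down:

def swa_update(buf, w, n):
    # equal-weight running mean: buf_{n+1} = (n * buf_n + w) / (n + 1)
    return buf + (w - buf) / (n + 1)


def ema_update(buf, w, beta=0.999):
    # exponential average: the old shadow weight keeps a factor beta,
    # so older iterates decay geometrically and recent ones dominate
    return beta * buf + (1 - beta) * w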

For that reason, the implementation here follows the torchcontrib SWA implementation and adds an EMA variant; the two differ only in how the shadow weights are updated.

import warnings
from collections import defaultdict

import torch
from torch.optim import Optimizer


class WeightAverage(Optimizer):
    def __init__(self, optimizer, wa_start=None, wa_freq=None, wa_lr=None, mode='swa'):
        """Implementation based on: https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
        Paper: Averaging Weights Leads to Wider Optima and Better Generalization
        Two averaging schemes: 'swa' and 'ema' (pass a float decay as `mode` to select EMA).
        Two modes of operation: automatic and manual.
        Args:
            optimizer (torch.optim.Optimizer): optimizer to wrap with weight averaging
            wa_start (int): step at which weight averaging starts (automatic mode)
            wa_freq (int): number of steps between averaging updates (automatic mode)
            wa_lr (float): learning rate applied from wa_start onwards (automatic mode)
        """
        if isinstance(mode, float):
            self.mode = 'ema'
            self.beta = mode
        else:
            self.mode = mode
        self._auto_mode, (self.wa_start, self.wa_freq) = self._check_params(wa_start, wa_freq)
        self.wa_lr = wa_lr
        # parameter sanity checks
        if self._auto_mode:
            if wa_start < 0:
                raise ValueError("Invalid wa_start: {}".format(wa_start))
            if wa_freq < 1:
                raise ValueError("Invalid wa_freq: {}".format(wa_freq))
        else:
            if self.wa_lr is not None:
                warnings.warn("Some of wa_start, wa_freq is None, ignoring wa_lr")
            self.wa_lr = None
            self.wa_start = None
            self.wa_freq = None

        if self.wa_lr is not None and self.wa_lr < 0:
            raise ValueError("Invalid WA learning rate: {}".format(wa_lr))

        self.optimizer = optimizer
        self.defaults = self.optimizer.defaults
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.opt_state = self.optimizer.state

        for group in self.param_groups:
            # EMA does not need the count of averaged models; kept for compatibility with SWA
            group['n_avg'] = 0
            group['step_counter'] = 0

    @staticmethod
    def _check_params(swa_start, swa_freq):
        """檢查參數,确認執行模式,并将參數轉為int
        """
        params = [swa_start, swa_freq]
        params_none = [param is None for param in params]
        if not all(params_none) and any(params_none):
            warnings.warn("Some of swa_start, swa_freq is None, ignoring other")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_start, swa_freq to int")
        return not any(params_none), params

    def _reset_lr_to_swa(self):
        """應用wa學習率
        """
        if self.wa_lr is None:
            return
        for param_group in self.param_groups:
            if param_group['step_counter'] >= self.wa_start:
                param_group['lr'] = self.wa_lr

    def update_swa_group(self, group):
        """更新一組參數的wa: 随機權重平均或者指數滑動平均
        """
        for p in group['params']:
            param_state = self.state[p]
            if 'wa_buffer' not in param_state:
                param_state['wa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['wa_buffer']
            if self.mode == 'swa':
                virtual_decay = 1 / float(group["n_avg"] + 1)
                diff = (p.data - buf) * virtual_decay  # buf + (p-buf) / (n+1) = (p + n*buf) / (n+1)
                buf.add_(diff)
            else:
                buf.mul_(self.beta).add_((1-self.beta) * p.data)
        group["n_avg"] += 1

    def update_swa(self):
        """手動模式:更新所有參數的swa
        """
        for group in self.param_groups:
            self.update_swa_group(group)

    def swap_swa_sgd(self):
        """1.交換swa和模型的參數 2.訓練結束時和評估時調用
        """
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'wa_buffer' not in param_state:
                    warnings.warn("WA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['wa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)

    def step(self, closure=None):
        """1.梯度更新 2.如果是自動模式更新swa參數
        """
        self._reset_lr_to_swa()
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode:
                if steps > self.wa_start and steps % self.wa_freq == 0:
                    self.update_swa_group(group)
        return loss

    def state_dict(self):
        """打包 opt_state 優化器狀态,swa_state SWA狀态,param_groups 參數組
        """
        opt_state_dict = self.optimizer.state_dict()
        wa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                     for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {"opt_state": opt_state, "wa_state": wa_state,
                "param_groups": param_groups}

    def load_state_dict(self, state_dict):
        """加載swa和優化器的狀态參數
        """
        wa_state_dict = {"state": state_dict["wa_state"],
                         "param_groups": state_dict["param_groups"]}
        opt_state_dict = {"state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(WeightAverage, self).load_state_dict(wa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state

    def add_param_group(self, param_group):
        """将一組參數添加到優化器的 `param_groups`.
        """
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)

    @staticmethod
    def bn_update(loader, model, device=None):
        """更新 BatchNorm running_mean, running_var
        """
        if not _check_bn(model):
            return
        was_training = model.training
        model.train()
        momenta = {}
        model.apply(_reset_bn)
        model.apply(lambda module: _get_momenta(module, momenta))
        n = 0
        for input in loader:
            if isinstance(input, (list, tuple)):
                input = input[0]
            b = input.size(0)  # batch_size

            momentum = b / float(n + b)
            for module in momenta.keys():
                module.momentum = momentum

            if device is not None:
                input = input.to(device)

            model(input)
            n += b

        model.apply(lambda module: _set_momenta(module, momenta))
        model.train(was_training)


# BatchNorm utils
def _check_bn_apply(module, flag):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True


def _check_bn(model):
    flag = [False]
    model.apply(lambda module: _check_bn_apply(module, flag))
    return flag[0]


def _reset_bn(module):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]
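

Finally, a usage sketch of the wrapper above in automatic mode; the model, data and hyper-parameters are made up for illustration, and gradients are cleared on the underlying optimizer (base_opt) directly:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 16), torch.nn.BatchNorm1d(16),
    torch.nn.ReLU(), torch.nn.Linear(16, 2))
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(20)]
criterion = torch.nn.CrossEntropyLoss()

base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
# automatic mode: start averaging at step 100, update every 10 steps, switch lr to 0.05
opt = WeightAverage(base_opt, wa_start=100, wa_freq=10, wa_lr=0.05, mode='swa')
# for EMA instead: opt = WeightAverage(base_opt, mode=0.999), then call opt.update_swa() manually

for epoch in range(10):
    for x, y in loader:
        base_opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()

opt.swap_swa_sgd()                      # copy the averaged weights into the model
WeightAverage.bn_update(loader, model)  # recompute BatchNorm statistics for the averaged weights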