""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb

This optimizer code was adapted from the following (starting with latest):
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb

Use FusedLamb if you can. This variant of Lamb is included to provide behaviour similar to APEX FusedLamb
for users who aren't running on NVIDIA GPUs or cannot install APEX for whatever reason.

Original copyrights for the above sources are below.
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.

# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
from torch.optim import Optimizer


class NvLamb(Optimizer):
    """Implements a pure PyTorch variant of the FusedLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

    LAMB was proposed in `Large Batch Optimization for Deep Learning - Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): whether to apply Adam-style bias correction
            to the moment estimates. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
        grad_averaging (bool, optional): whether to apply (1-beta1) to the grad when
            calculating the running average of the gradient. (default: True)
        set_grad_none (bool, optional): whether to set grads to None when zero_grad()
            is called. (default: True)
        max_grad_norm (float, optional): value used to clip the global grad norm
            (default: 1.0)
        use_nvlamb (bool, optional): apply the adaptive learning rate (trust ratio)
            to parameters with 0.0 weight decay as well (default: False)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 grad_averaging=True, set_grad_none=True,
                 max_grad_norm=1.0, use_nvlamb=False):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        super().__init__(params, defaults)
        self.set_grad_none = set_grad_none
        self.use_nvlamb = use_nvlamb

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super().zero_grad()

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        device = self.param_groups[0]["params"][0].device

        loss = None
        if closure is not None:
            loss = closure()

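        # Accumulate the squared L2 norms of all gradients so the update can be
        # pre-scaled by a single global clipping factor (FusedLAMB-style global
        # grad norm clipping) rather than clipping each tensor individually.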
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
                global_grad_norm.add_(grad.pow(2).sum())

        global_grad_norm_ = torch.sqrt(global_grad_norm)
        max_grad_norm = self.defaults['max_grad_norm']

        if global_grad_norm_ > max_grad_norm:
            clip_global_grad_norm = global_grad_norm_ / max_grad_norm
        else:
            clip_global_grad_norm = 1.0

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0
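            # beta3 is the weight on the incoming gradient in the first-moment update:
            # (1 - beta1) gives a true exponential moving average, 1.0 a decayed sum.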
            if grad_averaging:
                beta3 = 1 - beta1
            else:
                beta3 = 1.0

            # assume the same step across the group for now to simplify things
            # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            step_size = group['lr']

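            # Adam-style bias-correction factors (1 - beta^t); left at 1.0 when
            # bias correction is disabled for the group.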
            if bias_correction:
                bias_correction1 = 1 - beta1 ** group['step']
                bias_correction2 = 1 - beta2 ** group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg_, exp_avg_sq_ = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg_.mul_(beta1).add_(grad, alpha=beta3)
                # v_t
                exp_avg_sq_.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # create de-biased copies to avoid modifying the running stats
                exp_avg = exp_avg_.div(bias_correction1)
                exp_avg_sq = exp_avg_sq_.div(bias_correction2)

                # || w_t ||
                weight_norm = p.data.norm(2.0)
                # u_t
                exp_avg_sq_sqrt = torch.sqrt(exp_avg_sq)
                adam_step = exp_avg.div_(exp_avg_sq_sqrt.add_(group['eps']))
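                # Weight decay is added to the update direction itself (as in the
                # LAMB formulation), so it is scaled by the trust ratio together
                # with the Adam step rather than being folded into the gradient.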
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])
                # || u_t ||
                adam_norm = adam_step.norm(2.0)
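                # LAMB trust ratio r_t = ||w_t|| / ||u_t||: scales the layer-wise
                # update so its magnitude is proportional to the weight magnitude.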
                if (group['weight_decay'] != 0 or self.use_nvlamb) and adam_norm > 0 and weight_norm > 0:
                    trust_ratio = weight_norm / adam_norm
                    trust_ratio = trust_ratio.item()
                else:
                    trust_ratio = 1

                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio

                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss
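

# A minimal usage sketch (not part of the adapted sources). Names like `model`,
# `loader`, and `loss_fn` below are hypothetical placeholders for a standard
# training loop:
#
#   optimizer = NvLamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
#                      weight_decay=0.01, max_grad_norm=1.0)
#   for inputs, targets in loader:
#       optimizer.zero_grad()
#       loss = loss_fn(model(inputs), targets)
#       loss.backward()
#       optimizer.step()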