diff --git a/src/training/main.py b/src/training/main.py
index f70c9f953..4387e48ae 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -31,6 +31,7 @@
 from training.data import get_data
 from training.distributed import is_master, init_distributed_device, broadcast_object
 from training.logger import setup_logging
+from training.optimizers import Lion
 from training.params import parse_args
 from training.scheduler import cosine_lr, const_lr, const_lr_cooldown
 from training.train import train_one_epoch, evaluate
@@ -296,15 +297,29 @@
         gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
         rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
 
-        optimizer = optim.AdamW(
-            [
-                {"params": gain_or_bias_params, "weight_decay": 0.},
-                {"params": rest_params, "weight_decay": args.wd},
-            ],
-            lr=args.lr,
-            betas=(args.beta1, args.beta2),
-            eps=args.eps,
-        )
+        if 'lion' in args.opt:
+            logging.info('Using Lion optimizer.')
+            optimizer = Lion(
+                [
+                    {"params": gain_or_bias_params, "weight_decay": 0.},
+                    {"params": rest_params, "weight_decay": args.wd},
+                ],
+                lr=args.lr,
+                betas=(args.beta1, args.beta2),
+                use_triton='triton' in args.opt,
+            )
+        else:
+            logging.info('Using adamw optimizer.')
+            optimizer = optim.AdamW(
+                [
+                    {"params": gain_or_bias_params, "weight_decay": 0.},
+                    {"params": rest_params, "weight_decay": args.wd},
+                ],
+                lr=args.lr,
+                betas=(args.beta1, args.beta2),
+                eps=args.eps,
+            )
+
         if args.horovod:
             optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
             hvd.broadcast_parameters(model.state_dict(), root_rank=0)
diff --git a/src/training/optimizers/__init__.py b/src/training/optimizers/__init__.py
new file mode 100644
index 000000000..c62008cf4
--- /dev/null
+++ b/src/training/optimizers/__init__.py
@@ -0,0 +1 @@
+from .lion import Lion
\ No newline at end of file
diff --git a/src/training/optimizers/lion.py b/src/training/optimizers/lion.py
new file mode 100644
index 000000000..0d90db450
--- /dev/null
+++ b/src/training/optimizers/lion.py
@@ -0,0 +1,89 @@
+# This file is from https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py
+from typing import Tuple, Optional, Callable
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+# functions
+
+def exists(val):
+    return val is not None
+
+# update functions
+
+def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
+    # stepweight decay
+
+    p.data.mul_(1 - lr * wd)
+
+    # weight update
+
+    update = exp_avg.clone().lerp_(grad, 1 - beta1)
+    p.add_(torch.sign(update), alpha = -lr)
+
+    # decay the momentum running average coefficient
+
+    exp_avg.lerp_(grad, 1 - beta2)
+
+# class
+
+class Lion(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+        use_triton: bool = False
+    ):
+        assert lr > 0.
+        assert all([0. <= beta <= 1. for beta in betas])
+
+        defaults = dict(
+            lr = lr,
+            betas = betas,
+            weight_decay = weight_decay
+        )
+
+        super().__init__(params, defaults)
+
+        self.update_fn = update_fn
+
+        if use_triton:
+            from lion_pytorch.triton import update_fn as triton_update_fn
+            self.update_fn = triton_update_fn
+
+    @torch.no_grad()
+    def step(
+        self,
+        closure: Optional[Callable] = None
+    ):
+
+        loss = None
+        if exists(closure):
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in filter(lambda p: exists(p.grad), group['params']):
+
+                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
+
+                # init state - exponential moving average of gradient values
+
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+
+                exp_avg = state['exp_avg']
+
+                self.update_fn(
+                    p,
+                    grad,
+                    exp_avg,
+                    lr,
+                    wd,
+                    beta1,
+                    beta2
+                )
+
+        return loss
\ No newline at end of file
diff --git a/src/training/params.py b/src/training/params.py
index 36c693bc7..ec3b40088 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -140,6 +140,10 @@
     parser.add_argument(
         "--warmup", type=int, default=10000, help="Number of steps to warmup for."
     )
+    parser.add_argument(
+        "--opt", type=str, default='adamw',
+        help="Which optimizer to use. Choices are ['adamw', 'lion', 'lion-triton']."
+    )
     parser.add_argument(
         "--use-bn-sync",
         default=False,
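
Note: the update vendored in lion.py keeps a single momentum buffer per parameter and steps by the sign of an interpolation between that buffer and the current gradient, with decoupled ("stepweight") decay applied directly to the parameters. The sketch below is a quick smoke test of the vendored optimizer, not part of the diff; it assumes the package layout added above (training.optimizers.Lion) is importable and that only plain torch is installed (no Triton).

    import torch
    from training.optimizers import Lion  # package introduced in this diff (assumed on the Python path)

    model = torch.nn.Linear(16, 4)
    # The Lion paper suggests a roughly 3-10x smaller learning rate (and larger
    # weight decay) than AdamW; the values here are illustrative only.
    optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

    x = torch.randn(8, 16)
    target = torch.randn(8, 4)

    for _ in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), target)
        loss.backward()
        optimizer.step()  # sign-of-interpolated-momentum update with stepweight decay

In the training script itself the new flag selects the optimizer, e.g. --opt lion, or --opt lion-triton to route updates through the fused Triton kernel from lion-pytorch (which must then be installed).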