Source code for lmp.util.optim

"""Optimization utilities."""

import torch

import lmp.util.validate
from lmp.model import BaseModel


[docs]def get_optimizer( beta1: float, beta2: float, eps: float, lr: float, model: BaseModel, weight_decay: float, ) -> torch.optim.AdamW: """Get AdamW optimizer. Parameters ---------- beta1: float First coefficient of gradient moving average. beta2: float Second coefficient of gradient moving average. eps: float Numerically saved computation term. lr: float Learning rate of gradient descent. model: lmp.model.BaseModel Language model to be optimized. weight_decay: float Weight decay coefficient. Returns ------- torch.optim.AdamW Language model optimizer. See Also -------- torch.optim.AdamW AdamW algorithm. """ # `beta1` validation. lmp.util.validate.raise_if_not_instance(val=beta1, val_name='beta1', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, beta1, 1.0], val_names=['0.0', 'beta1', '1.0']) # `beta2` validation. lmp.util.validate.raise_if_not_instance(val=beta2, val_name='beta2', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, beta2, 1.0], val_names=['0.0', 'beta2', '1.0']) # `eps` validation. lmp.util.validate.raise_if_not_instance(val=eps, val_name='eps', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, eps], val_names=['0.0', 'eps']) # `lr` validation. lmp.util.validate.raise_if_not_instance(val=lr, val_name='lr', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, lr], val_names=['0.0', 'lr']) # `model` validation. lmp.util.validate.raise_if_not_instance(val=model, val_name='model', val_type=BaseModel) # `weight_decay` validation. lmp.util.validate.raise_if_not_instance(val=weight_decay, val_name='weight_decay', val_type=float) lmp.util.validate.raise_if_wrong_ordered(vals=[0.0, weight_decay], val_names=['0.0', 'weight_decay']) # Remove weight decay on bias and layer-norm. This can only be done after moving model to running device. no_decay = ['bias', 'LayerNorm.weight'] optim_group_params = [ { 'params': [param for name, param in model.named_parameters() if not any(nd in name for nd in no_decay)], 'weight_decay': weight_decay, }, { 'params': [param for name, param in model.named_parameters() if any(nd in name for nd in no_decay)], 'weight_decay': 0.0, }, ] # Get new optimizer instance. return torch.optim.AdamW(optim_group_params, betas=(beta1, beta2), eps=eps, lr=lr)
[docs]def get_scheduler(optim: torch.optim.AdamW, total_step: int, warmup_step: int) -> torch.optim.lr_scheduler.LambdaLR: """Get linearly decay scheduler with linearly warm up. Learning rate will first linearly increase (warm up) to the specified value, then linearly decay to ``0``. Parameters ---------- optim: torch.optim.AdamW Optimizer to be scheduled. total_step: int Total training step. warmup_step: int Learning rate warmup step. Returns ------- torch.optim.lr_scheduler.LambdaLR Optimizer learning rate scheduler. """ # `optim` validation. lmp.util.validate.raise_if_not_instance(val=optim, val_name='optim', val_type=torch.optim.AdamW) # `total_step` and `warmup_step` validation. lmp.util.validate.raise_if_wrong_ordered( vals=[1, warmup_step, total_step], val_names=['1', 'warmup_step', 'total_step'], ) def lr_lambda(step: int) -> float: # Warm up phase. if step < warmup_step: return float(step / max(1, warmup_step)) # Decay phase. return float(max(0, (total_step - step) / max(1, total_step - warmup_step))) return torch.optim.lr_scheduler.LambdaLR(optimizer=optim, lr_lambda=lr_lambda, last_epoch=-1)