# Source code for machine.optim.optim

import itertools

import torch


class Optimizer(object):
    """ The Optimizer class encapsulates torch.optim package and provides functionalities
    for learning rate scheduling and gradient norm clipping.

    Args:
        optim (torch.optim.Optimizer): optimizer object, the parameters to be optimized
            should be given when instantiating the object, e.g. torch.optim.SGD(params)
        max_grad_norm (float, optional): value used for gradient norm clipping,
            set 0 to disable (default 0). Any value <= 0, or None, disables clipping.
    """

    _ARG_MAX_GRAD_NORM = 'max_grad_norm'

    def __init__(self, optim, max_grad_norm=0):
        self.optimizer = optim
        self.scheduler = None
        self.max_grad_norm = max_grad_norm

    def set_scheduler(self, scheduler):
        """ Set the learning rate scheduler.

        Args:
            scheduler (torch.optim.lr_scheduler.*): object of learning rate scheduler,
                e.g. torch.optim.lr_scheduler.StepLR
        """
        self.scheduler = scheduler

    def _clip_gradients(self):
        """Clip the total gradient norm over all parameter groups to ``max_grad_norm``.

        Does nothing when ``max_grad_norm`` is None or not strictly positive.
        """
        if self.max_grad_norm is not None and self.max_grad_norm > 0:
            params = itertools.chain.from_iterable(
                group['params'] for group in self.optimizer.param_groups)
            torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)

    def step(self, closure=None):
        """ Performs a single optimization step, including gradient norm clipping if necessary.

        Args:
            closure (callable, optional): a closure that re-evaluates the model and
                returns the loss, as accepted by ``torch.optim.Optimizer.step``
                (required by optimizers such as ``torch.optim.LBFGS``).

        Returns:
            Whatever the wrapped optimizer's ``step`` returns (the closure's loss when
            a closure is given; typically None otherwise).
        """
        if closure is None:
            self._clip_gradients()
            return self.optimizer.step()

        def clipped_closure():
            # The closure recomputes the gradients, so clipping must happen after it
            # runs (possibly several times per step), not before optimizer.step().
            loss = closure()
            self._clip_gradients()
            return loss

        return self.optimizer.step(clipped_closure)

    def update(self, loss, epoch):
        """ Update the learning rate if the criteria of the scheduler are met.

        Args:
            loss (float): The current loss.  It could be training loss or developing loss
                depending on the caller.  By default the supervised trainer uses developing
                loss.  Only passed to ``ReduceLROnPlateau`` schedulers.
            epoch (int): The current epoch number.  Currently unused; kept for
                backward compatibility of the call signature.
        """
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Plateau schedulers decide based on the monitored metric.
            self.scheduler.step(loss)
        else:
            self.scheduler.step()