Python torch.optim.Optimizer() Examples

The following are 30 code examples of torch.optim.Optimizer, drawn from open-source projects. The project, source file, and license for each example are listed above its code. You may also want to look through the other functions and classes available in the torch.optim module.
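torch.optim.Optimizer is the abstract base class that every built-in PyTorch optimizer (SGD, Adam, Adamax, ...) inherits from, which is why the examples below mostly use it for isinstance checks and type hints rather than instantiating it directly. As a quick orientation, here is a minimal sketch of a custom optimizer subclass; the class name PlainSGD and the toy usage are purely illustrative, not taken from any of the projects below.

import torch
from torch.optim import Optimizer


class PlainSGD(Optimizer):
    """Illustrative optimizer: p <- p - lr * grad for every parameter."""

    def __init__(self, params, lr=1e-2):
        if lr <= 0:
            raise ValueError("lr must be positive")
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])
        return loss


# toy usage: one gradient step on a linear layer
model = torch.nn.Linear(4, 2)
opt = PlainSGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()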
Example #1
Source File: scheduler.py    From seamseg with BSD 3-Clause "New" or "Revised" License
def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch 
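The constructor above mirrors the __init__ of PyTorch's _LRScheduler base class: it validates the optimizer, records each param group's initial_lr into base_lrs, and takes an initial step. A concrete scheduler built on the real torch.optim.lr_scheduler._LRScheduler only needs to supply get_lr; the HalveEveryEpoch class below is an illustrative sketch, not part of seamseg.

import torch
from torch.optim.lr_scheduler import _LRScheduler


class HalveEveryEpoch(_LRScheduler):
    """Illustrative scheduler: lr = initial_lr * 0.5 ** last_epoch."""

    def get_lr(self):
        return [base_lr * 0.5 ** self.last_epoch for base_lr in self.base_lrs]


model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = HalveEveryEpoch(optimizer)
for epoch in range(3):
    optimizer.step()   # training steps would go here
    scheduler.step()   # decay the learning rate once per epoch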
Example #2
Source File: builders.py    From joeynmt with Apache License 2.0
def __init__(self, hidden_size: int, optimizer: torch.optim.Optimizer,
                 factor: float = 1, warmup: int = 4000):
        """
        Warm-up, followed by learning rate decay.

        :param hidden_size:
        :param optimizer:
        :param factor: decay factor
        :param warmup: number of warmup steps
        """
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.hidden_size = hidden_size
        self._rate = 0 
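Only the constructor is shown; the rate this scheduler applies follows the Noam schedule from "Attention Is All You Need" (scale by hidden_size^-0.5, linear warmup, then inverse-square-root decay). The helper below is a standalone sketch of that formula written for this page, not joeynmt's exact step/rate code.

import torch


def noam_rate(step: int, hidden_size: int, factor: float = 1.0, warmup: int = 4000) -> float:
    """Noam schedule: lr = factor * hidden_size^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # avoid division by zero on the very first update
    return factor * hidden_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


# toy usage: push the rate into an optimizer each step, as the scheduler would
optimizer = torch.optim.Adam(torch.nn.Linear(8, 8).parameters(), lr=0.0)
for step in range(1, 5):
    for group in optimizer.param_groups:
        group['lr'] = noam_rate(step, hidden_size=512)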
Example #3
Source File: train.py    From few-shot with MIT License
def gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable, x: torch.Tensor, y: torch.Tensor, **kwargs):
    """Takes a single gradient step.

    # Arguments
        model: Model to be fitted
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to compute between predictions and targets
        x: Input samples
        y: Input targets
    """
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred 
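A minimal usage sketch, assuming gradient_step as defined above is in scope; the toy model, optimiser, and data are illustrative.

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimiser = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(32, 10)                 # a batch of 32 inputs
y = torch.randint(0, 2, (32,))          # matching class labels
loss, y_pred = gradient_step(model, optimiser, loss_fn, x, y)
print(loss.item(), y_pred.shape)        # scalar loss, logits of shape (32, 2)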
Example #4
Source File: builders.py    From joeynmt with Apache License 2.0
def __init__(self, optimizer: torch.optim.Optimizer,
                 peak_rate: float = 1.0e-3,
                 decay_length: int = 10000, warmup: int = 4000,
                 decay_rate: float = 0.5, min_rate: float = 1.0e-5):
        """
        Warm-up, followed by exponential learning rate decay.

        :param peak_rate: maximum learning rate at peak after warmup
        :param optimizer:
        :param decay_length: decay length after warmup
        :param decay_rate: decay rate after warmup
        :param warmup: number of warmup steps
        :param min_rate: minimum learning rate
        """
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.decay_length = decay_length
        self.peak_rate = peak_rate
        self._rate = 0
        self.decay_rate = decay_rate
        self.min_rate = min_rate 
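The decay computation itself is not part of the snippet. One plausible completion, written as a standalone helper rather than joeynmt's exact method (the project's precise formula may differ): linear warmup to peak_rate, then exponential decay with a floor at min_rate.

def warmup_exponential_rate(step: int, peak_rate: float = 1.0e-3, warmup: int = 4000,
                            decay_length: int = 10000, decay_rate: float = 0.5,
                            min_rate: float = 1.0e-5) -> float:
    """Sketch: linear warmup, then peak_rate * decay_rate ** ((step - warmup) / decay_length)."""
    if step < warmup:
        return peak_rate * step / warmup
    exponent = (step - warmup) / decay_length
    return max(peak_rate * decay_rate ** exponent, min_rate)


# e.g. warmup_exponential_rate(2000) -> 5e-4, (4000) -> 1e-3, (14000) -> 5e-4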
Example #5
Source File: base_runner.py    From mmcv with Apache License 2.0
def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr 
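The same lookup works on any plain optimizer outside a runner; a quick standalone check with two param groups:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD([
    {'params': model.weight, 'lr': 0.1},   # group-specific learning rate
    {'params': model.bias},                # falls back to the global lr
], lr=0.01)
print([group['lr'] for group in optimizer.param_groups])  # [0.1, 0.01]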
Example #6
Source File: checkpointing.py    From visdial-challenge-starter-pytorch with BSD 3-Clause "New" or "Revised" License
def __init__(
        self,
        model,
        optimizer,
        checkpoint_dirpath,
        step_size=1,
        last_epoch=-1,
        **kwargs,
    ):

        if not isinstance(model, nn.Module):
            raise TypeError("{} is not a Module".format(type(model).__name__))

        if not isinstance(optimizer, optim.Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )

        self.model = model
        self.optimizer = optimizer
        self.ckpt_dirpath = Path(checkpoint_dirpath)
        self.step_size = step_size
        self.last_epoch = last_epoch
        self.init_directory(**kwargs) 
Example #7
Source File: lr_scheduler.py    From pytorch-planet-amazon with Apache License 2.0
def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch 
Example #8
Source File: model.py    From deep_pipe with MIT License
def optimizer_step(optimizer: Optimizer, loss: torch.Tensor, **params) -> torch.Tensor:
    """
    Performs the backward pass with respect to ``loss``, as well as a gradient step.

    ``params`` is used to change the optimizer's parameters.

    Examples
    --------
    >>> optimizer = Adam(model.parameters(), lr=1)
    >>> optimizer_step(optimizer, loss) # perform a gradient step
    >>> optimizer_step(optimizer, loss, lr=1e-3) # set lr to 1e-3 and perform a gradient step
    >>> optimizer_step(optimizer, loss, betas=(0, 0)) # set betas to 0 and perform a gradient step

    Notes
    -----
    The incoming ``optimizer``'s parameters are not restored to their original values.
    """
    set_params(optimizer, **params)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss 
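set_params is imported from deep_pipe and not shown here; a stand-in that simply writes the keyword arguments into every param group is enough to run the snippet, though the library's real implementation may differ.

import torch
from torch.optim import Adam


def set_params(optimizer, **params):
    """Sketch: push the given hyperparameters (lr, betas, ...) into every param group."""
    for group in optimizer.param_groups:
        group.update(params)


model = torch.nn.Linear(3, 1)
optimizer = Adam(model.parameters(), lr=1)
loss = model(torch.randn(5, 3)).pow(2).mean()
optimizer_step(optimizer, loss, lr=1e-3)   # runs the backward pass with lr overridden to 1e-3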
Example #9
Source File: lr_scheduler.py    From virtex with MIT License
def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        milestones: List[int],
        gamma: float = 0.1,
        last_epoch: int = -1,
    ):
        self.wsteps = warmup_steps
        self.milestones = milestones
        self.gamma = gamma

        # Keep a track of number of milestones encountered.
        self.milestones_so_far = 0

        # Common sanity checks.
        assert milestones == sorted(milestones), "milestones must be increasing"
        assert milestones[0] > warmup_steps, "first milestone must be after warmup"
        assert (
            milestones[-1] < total_steps
        ), "last milestone must be less than total steps"

        super().__init__(optimizer, self._lr_multiplier, last_epoch) 
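The _lr_multiplier method handed to the parent class (a LambdaLR-style scheduler) is not shown. Below is a standalone sketch of the multiplier such a schedule computes, not virtex's exact code:

def warmup_multistep_multiplier(step: int, warmup_steps: int, milestones, gamma: float = 0.1) -> float:
    """Sketch: multiplier on the base lr; linear warmup, then gamma per milestone passed."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return gamma ** sum(1 for m in milestones if m <= step)


# with warmup_steps=100, milestones=[1000, 2000], gamma=0.1:
# step 50 -> 0.5, step 500 -> 1.0, step 1500 -> 0.1, step 2500 -> 0.01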
Example #10
Source File: cyclic_lr.py    From aivivn-tone with MIT License
def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
            last_epoch = 0
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch
        self.step(last_epoch) 
Example #11
Source File: model.py    From SlowFast-Network-pytorch with MIT License
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])

        # model_dict = self.state_dict()
        # pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}  # filter out unnecessary keys
        # model_dict.update(pretrained_dict)
        # self.load_state_dict(model_dict)
        # torch.nn.DataParallel(self).cuda()
        #step = checkpoint['step']
        step=0
        # if optimizer is not None:
        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # if scheduler is not None:
        #     scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return step 
Example #12
Source File: checkpointing.py    From updown-baseline with MIT License
def __init__(
        self,
        models: Union[nn.Module, Dict[str, nn.Module]],
        optimizer: Type[optim.Optimizer],
        serialization_dir: str,
        mode: str = "max",
        filename_prefix: str = "checkpoint",
    ):

        # Convert single model to a dict.
        if isinstance(models, nn.Module):
            models = {"model": models}

        for key in models:
            if not isinstance(models[key], nn.Module):
                raise TypeError("{} is not a Module".format(type(models[key]).__name__))

        if not isinstance(optimizer, optim.Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))

        self._models = models
        self._optimizer = optimizer
        self._serialization_dir = serialization_dir

        self._mode = mode
        self._filename_prefix = filename_prefix

        # Initialize members to hold state dict of best checkpoint and its performance.
        self._best_metric: Optional[Union[float, torch.Tensor]] = None
        self._best_ckpt: Dict[str, Any] = {} 
Example #13
Source File: modelwrapper.py    From baal with Apache License 2.0
def train_on_dataset(self, dataset, optimizer, batch_size, epoch, use_cuda, workers=4,
                         collate_fn: Optional[Callable] = None,
                         regularizer: Optional[Callable] = None):
        """
        Train for `epoch` epochs on a Dataset `dataset`.

        Args:
            dataset (Dataset): Pytorch Dataset to be trained on.
            optimizer (optim.Optimizer): Optimizer to use.
            batch_size (int): The batch size used in the DataLoader.
            epoch (int): Number of epoch to train for.
            use_cuda (bool): Use cuda or not.
            workers (int): Number of workers for the multiprocessing.
            collate_fn (Optional[Callable]): The collate function to use.
            regularizer (Optional[Callable]): The loss regularization for training.

        Returns:
            The training history.
        """
        self.train()
        history = []
        log.info("Starting training", epoch=epoch, dataset=len(dataset))
        collate_fn = collate_fn or default_collate
        for _ in range(epoch):
            self._reset_metrics('train')
            for data, target in DataLoader(dataset, batch_size, True, num_workers=workers,
                                           collate_fn=collate_fn):
                _ = self.train_on_batch(data, target, optimizer, use_cuda, regularizer)
            history.append(self.metrics['train_loss'].value)

        optimizer.zero_grad()  # Assert that the gradient is flushed.
        log.info("Training complete", train_loss=self.metrics['train_loss'].value)
        return history 
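A usage sketch, assuming baal's ModelWrapper from this same file with its (model, criterion) constructor; the toy classifier and dataset are illustrative.

import torch
from torch import nn, optim
from torch.utils.data import TensorDataset
from baal.modelwrapper import ModelWrapper   # assumption: baal's import path for this class

model = nn.Linear(16, 2)
wrapper = ModelWrapper(model, nn.CrossEntropyLoss())

dataset = TensorDataset(torch.randn(128, 16), torch.randint(0, 2, (128,)))
optimizer = optim.SGD(model.parameters(), lr=1e-2)
history = wrapper.train_on_dataset(dataset, optimizer, batch_size=32, epoch=2,
                                   use_cuda=False, workers=0)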
Example #14
Source File: modelwrapper.py    From baal with Apache License 2.0
def train_on_batch(self, data, target, optimizer, cuda=False,
                       regularizer: Optional[Callable] = None):
        """
        Train the current model on a batch using `optimizer`.

        Args:
            data (Tensor): The model input.
            target (Tensor): The ground truth.
            optimizer (optim.Optimizer): An optimizer.
            cuda (bool): Use CUDA or not.
            regularizer (Optional[Callable]): The loss regularization for training.


        Returns:
            Tensor, the loss computed from the criterion.
        """

        if cuda:
            data, target = to_cuda(data), to_cuda(target)
        optimizer.zero_grad()
        output = self.model(data)
        loss = self.criterion(output, target)

        if regularizer:
            regularized_loss = loss + regularizer()
            regularized_loss.backward()
        else:
            loss.backward()

        optimizer.step()
        self._update_metrics(output, target, loss, filter='train')
        return loss 
Example #15
Source File: DoubleDQNAgent.py    From DeepRL with MIT License
def __init__(
            self,
            _model: nn.Module,
            _env: EnvAbstract,
            _gamma: float, _batch_size: int,
            _epsilon_init: float,
            _epsilon_decay: float,
            _epsilon_underline: float,
            _replay: ReplayAbstract = None,
            _optimizer: optim.Optimizer = None,
            _err_clip: float = None, _grad_clip: float = None
    ):
        super().__init__(_env)

        self.q_func: nn.Module = _model
        self.target_q_func: nn.Module = deepcopy(_model)
        for param in self.target_q_func.parameters():
            param.requires_grad_(False)

        # set config
        self.config.gamma = _gamma
        self.config.batch_size = _batch_size
        self.config.epsilon = _epsilon_init
        self.config.epsilon_decay = _epsilon_decay
        self.config.epsilon_underline = _epsilon_underline
        self.config.err_clip = _err_clip
        self.config.grad_clip = _grad_clip

        self.replay = _replay

        self.criterion = nn.MSELoss()
        self.optimizer = _optimizer 
Example #16
Source File: test_setter.py    From skorch with BSD 3-Clause "New" or "Revised" License
def optimizer_dummy(self):
        from torch.optim import Optimizer
        optim = Mock(spec=Optimizer)
        optim.param_groups = [
            {'lr': 0.01, 'momentum': 0.9},
            {'lr': 0.02, 'momentum': 0.9}
        ]
        return optim 
Example #17
Source File: LBFGS.py    From PyTorch-LBFGS with MIT License
def step(self, p_k, g_Ok, g_Sk=None, options={}):
        return self._step(p_k, g_Ok, g_Sk, options)

#%% Full-Batch (Deterministic) L-BFGS Optimizer (Wrapper) 
Example #18
Source File: model.py    From easy-fpn.pytorch with MIT License
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
        checkpoint = {
            'state_dict': self.state_dict(),
            'step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint 
Example #19
Source File: model.py    From easy-fpn.pytorch with MIT License
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])
        step = checkpoint['step']
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return step 
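Examples #18 and #19 together form a checkpoint round trip. The same dict layout works as standalone helpers; the sketch below uses helper names chosen for this page rather than the project's methods, so it runs with a plain nn.Module.

import os
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR


def save_checkpoint(model, path_to_checkpoints_dir, step, optimizer, scheduler):
    path = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
    torch.save({'state_dict': model.state_dict(),
                'step': step,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()}, path)
    return path


def load_checkpoint(model, path_to_checkpoint, optimizer=None, scheduler=None):
    checkpoint = torch.load(path_to_checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['step']


model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10)
path = save_checkpoint(model, '.', 1000, optimizer, scheduler)
assert load_checkpoint(model, path, optimizer, scheduler) == 1000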
Example #20
Source File: reinforce_learn_Qnet.py    From pytorch-lightning with Apache License 2.0
def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer] 
Example #21
Source File: carlini.py    From fast_adversarial with BSD 3-Clause "New" or "Revised" License
def _step(self, model: nn.Module, optimizer: optim.Optimizer, inputs: torch.Tensor, tinputs: torch.Tensor,
              modifier: torch.Tensor, labels: torch.Tensor, labels_infhot: torch.Tensor, targeted: bool,
              const: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        batch_size = inputs.shape[0]
        adv_input = torch.tanh(tinputs + modifier) * self.boxmul + self.boxplus
        l2 = (adv_input - inputs).view(batch_size, -1).pow(2).sum(1)

        logits = model(adv_input)

        real = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
        other = (logits - labels_infhot).max(1)[0]
        if targeted:
            # if targeted, optimize for making the other class most likely
            logit_dists = torch.clamp(other - real + self.confidence, min=0)
        else:
            # if non-targeted, optimize for making this class least likely.
            logit_dists = torch.clamp(real - other + self.confidence, min=0)

        logit_loss = (const * logit_dists).sum()
        l2_loss = l2.sum()
        loss = logit_loss + l2_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return adv_input.detach(), logits.detach(), l2.detach(), logit_dists.detach(), loss.detach() 
Example #22
Source File: train_config.py    From neural-pipeline with MIT License
def __init__(self, model: Module, train_stages: [], loss: Module, optimizer: Optimizer):
        self._train_stages = train_stages
        self._loss = loss
        self._optimizer = optimizer
        self._model = model 
Example #23
Source File: trainer.py    From snorkel with Apache License 2.0
def _set_optimizer(self, model: nn.Module) -> None:
        optimizer_config = self.config.optimizer_config
        optimizer_name = self.config.optimizer

        parameters = filter(lambda p: p.requires_grad, model.parameters())

        optimizer: optim.Optimizer  # type: ignore

        if optimizer_name == "sgd":
            optimizer = optim.SGD(  # type: ignore
                parameters,
                lr=self.config.lr,
                weight_decay=self.config.l2,
                **optimizer_config.sgd_config._asdict(),
            )
        elif optimizer_name == "adam":
            optimizer = optim.Adam(
                parameters,
                lr=self.config.lr,
                weight_decay=self.config.l2,
                **optimizer_config.adam_config._asdict(),
            )
        elif optimizer_name == "adamax":
            optimizer = optim.Adamax(  # type: ignore
                parameters,
                lr=self.config.lr,
                weight_decay=self.config.l2,
                **optimizer_config.adamax_config._asdict(),
            )
        else:
            raise ValueError(f"Unrecognized optimizer option '{optimizer_name}'")

        logging.info(f"Using optimizer {optimizer}")

        self.optimizer = optimizer 
Example #24
Source File: label_model.py    From snorkel with Apache License 2.0
def _set_optimizer(self) -> None:
        parameters = filter(lambda p: p.requires_grad, self.parameters())

        optimizer_config = self.train_config.optimizer_config
        optimizer_name = self.train_config.optimizer
        optimizer: optim.Optimizer  # type: ignore

        if optimizer_name == "sgd":
            optimizer = optim.SGD(  # type: ignore
                parameters,
                lr=self.train_config.lr,
                weight_decay=self.train_config.l2,
                **optimizer_config.sgd_config._asdict(),
            )
        elif optimizer_name == "adam":
            optimizer = optim.Adam(
                parameters,
                lr=self.train_config.lr,
                weight_decay=self.train_config.l2,
                **optimizer_config.adam_config._asdict(),
            )
        elif optimizer_name == "adamax":
            optimizer = optim.Adamax(  # type: ignore
                parameters,
                lr=self.train_config.lr,
                weight_decay=self.train_config.l2,
                **optimizer_config.adamax_config._asdict(),
            )
        else:
            raise ValueError(f"Unrecognized optimizer option '{optimizer_name}'")

        self.optimizer = optimizer 
Example #25
Source File: lamb.py    From pytorch-lamb with MIT License
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
    """Log a histogram of trust ratio scalars in across layers."""
    results = collections.defaultdict(list)
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
                if i in state:
                    results[i].append(state[i])

    for k, v in results.items():
        event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 
Example #26
Source File: lr_scheduler.py    From pytorch-planet-amazon with Apache License 2.0
def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset() 
Example #27
Source File: model_builder.py    From lumin with Apache License 2.0
def get_model(self) -> Tuple[nn.Module, optim.Optimizer, Any, Any]:
        r'''
        Construct model, loss, and optimiser, optionally loading pretrained weights

        Returns:
            Instantiated network, optimiser linked to model parameters, uninstantiated loss, and the input mask
        '''

        model = self.build_model()
        if self.pretrain_file is not None: self.load_pretrained(model)
        model = to_device(model)
        opt = self._build_opt(model)
        return model, opt, self.loss, self.input_mask 
Example #28
Source File: model_builder.py    From lumin with Apache License 2.0
def _build_opt(self, model:nn.Module) -> optim.Optimizer:
        if isinstance(self.opt, str): self.opt = self._interp_opt(self.opt)  # Backwards compatibility with pre-v0.3.1 saves
        return self.opt(model.parameters(), **self.opt_args) 
Example #29
Source File: model_builder.py    From lumin with Apache License 2.0
def _interp_opt(opt:str) -> Callable[[Iterator, Optional[Any]], optim.Optimizer]:
        opt = opt.lower()
        if   opt == 'adam':   return optim.Adam
        elif opt == 'sgd':    return optim.SGD
        else: raise ValueError(f"Optimiser {opt} not interpretable from string, please pass as class") 
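A quick sketch of what the string-to-class lookup enables, assuming the _interp_opt shown above is available as a plain function; the values are illustrative.

from torch import nn, optim

opt_class = _interp_opt('adam')                 # resolves the string to optim.Adam
model = nn.Linear(8, 1)
opt = opt_class(model.parameters(), lr=1e-3)    # instantiated the same way _build_opt does with opt_args
assert isinstance(opt, optim.Optimizer)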
Example #30
Source File: __init__.py    From pytorch-saltnet with MIT License
def create_lr_scheduler(optimizer, lr_scheduler, **kwargs):
    if not isinstance(optimizer, optim.Optimizer):
        # assume FP16_Optimizer
        optimizer = optimizer.optimizer

    if lr_scheduler == 'plateau':
        patience = kwargs.get('lr_scheduler_patience', 10) // kwargs.get('validation_interval', 1)
        factor = kwargs.get('lr_scheduler_gamma', 0.1)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=factor, eps=0)
    elif lr_scheduler == 'step':
        step_size = kwargs['lr_scheduler_step_size']
        gamma = kwargs.get('lr_scheduler_gamma', 0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    elif lr_scheduler == 'cos':
        max_epochs = kwargs['max_epochs']
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_epochs)
    elif lr_scheduler == 'milestones':
        milestones = kwargs['lr_scheduler_milestones']
        gamma = kwargs.get('lr_scheduler_gamma', 0.1)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
    elif lr_scheduler == 'findlr':
        max_steps = kwargs['max_steps']
        lr_scheduler = FindLR(optimizer, max_steps)
    elif lr_scheduler == 'noam':
        warmup_steps = kwargs['lr_scheduler_warmup']
        lr_scheduler = NoamLR(optimizer, warmup_steps=warmup_steps)
    elif lr_scheduler == 'clr':
        step_size = kwargs['lr_scheduler_step_size']
        learning_rate = kwargs['learning_rate']
        lr_scheduler_gamma = kwargs['lr_scheduler_gamma']
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                  T_max=step_size,
                                                                  eta_min=learning_rate * lr_scheduler_gamma)
    else:
        raise NotImplementedError("unknown lr_scheduler " + lr_scheduler)
    return lr_scheduler
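A usage sketch exercising the built-in 'step' branch, assuming create_lr_scheduler above is in scope; the 'findlr' and 'noam' branches additionally need the project's FindLR and NoamLR classes. Keyword names follow the function above.

from torch import nn, optim

model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = create_lr_scheduler(optimizer, 'step',
                                lr_scheduler_step_size=30,
                                lr_scheduler_gamma=0.5)
for epoch in range(3):
    optimizer.step()
    scheduler.step()   # halves the lr every 30 scheduler steps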