Python torch.optim.AdamW() Examples

The following are 15 code examples of torch.optim.AdamW(), drawn from open-source projects. Each example notes its source file, project, and license. You may also want to check out all available functions/classes of the module torch.optim.
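As a starting point, here is a minimal, self-contained sketch of the usual torch.optim.AdamW setup; the toy linear model and the hyperparameter values are illustrative only and are not taken from any of the projects below.

import torch
from torch import nn, optim

# Toy model; any nn.Module works the same way.
model = nn.Linear(10, 2)

# AdamW decouples weight decay from the gradient-based update.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)

# One training step on random data.
x = torch.randn(4, 10)
y = torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Unlike optim.Adam, AdamW applies weight decay directly to the weights rather than folding it into the gradient, which is why several of the examples below pass a non-trivial weight_decay to it.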
Example #1
Source File: train.py    From aitextgen with MIT License
def configure_optimizers(self):
        "Prepare optimizer"

        no_decay = ["bias", "LayerNorm.weight"]
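        # Two parameter groups: weight decay is applied to all parameters except those named above.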
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams["weight_decay"],
            },
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams["learning_rate"],
            eps=self.hparams["adam_epsilon"],
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams["warmup_steps"],
            num_training_steps=self.hparams["num_steps"],
        )

        return [optimizer], [scheduler] 
Example #2
Source File: test_qhadam.py    From qhoptim with MIT License
def test_adam_equiv():
    lr = 3e-4
    betas = (0.9, 0.999)
    weight_decay = 0.5e-4
    eps = 1e-8

    def adam_ctor(params):
        return Adam(params, lr=lr, betas=betas, weight_decay=weight_decay, eps=eps)

    def qhadam_ctor(params):
        return QHAdam(params, lr=lr, betas=betas, weight_decay=weight_decay, nus=(1.0, 1.0), eps=eps)

    def adamw_ctor(params):
        return AdamW(params, lr=lr, betas=betas, weight_decay=weight_decay, eps=eps)

    def qhadamw_ctor(params):
        return QHAdamW(params, lr=lr, betas=betas, weight_decay=weight_decay, nus=(1.0, 1.0), eps=eps)

    assert_optimizers_equal(adam_ctor, qhadam_ctor) 
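This test relies on the fact that QHAdam with nus=(1.0, 1.0) reduces to plain Adam (and, analogously, QHAdamW to AdamW), so the paired constructors should produce identical updates.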
Example #3
Source File: ppo_map.py    From doom-net-pytorch with MIT License
def __init__(self, args):
        self.model = BaseModel(
            args.screen_size[0]*args.frame_num, args.button_num, args.variable_num, args.frame_num, args.batch_size
        ).to(device)
        if args.load is not None:
            # load weights
            state_dict = torch.load(args.load)
            self.model.load_state_dict(state_dict)

        self.discount = args.episode_discount
        self.steps = []
        self.rewards = []
        self.non_terminals = []
        self.non_terminal = torch.ones(args.batch_size, 1)

        self.cells = Cells(2, self.model.screen_feature_num, args.batch_size)
        self.init_cells = self.cells.clone()

        self.optimizer = optim.AdamW(self.model.parameters(), lr=args.learning_rate, weight_decay=1e-6, amsgrad=True)
        '''
        if args.load is not None and os.path.isfile(args.load + '_optimizer.pth'):
            optimizer_dict = torch.load(args.load+'_optimizer.pth')
            optimizer.load_state_dict(optimizer_dict)
        '''
        self.optimizer.zero_grad()
        self.args = args 
Example #4
Source File: misc.py    From self-critical.pytorch with MIT License
def build_optimizer(params, opt):
    if opt.optim == 'rmsprop':
        return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
    elif opt.optim == 'adagrad':
        return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgd':
        return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgdm':
        return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgdmom':
        return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
    elif opt.optim == 'adam':
        return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
    elif opt.optim == 'adamw':
        return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
    else:
        raise Exception("bad option opt.optim: {}".format(opt.optim)) 
Example #5
Source File: misc.py    From ImageCaptioning.pytorch with MIT License
def build_optimizer(params, opt):
    if opt.optim == 'rmsprop':
        return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
    elif opt.optim == 'adagrad':
        return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgd':
        return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgdm':
        return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgdmom':
        return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
    elif opt.optim == 'adam':
        return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
    elif opt.optim == 'adamw':
        return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
    else:
        raise Exception("bad option opt.optim: {}".format(opt.optim)) 
Example #6
Source File: main.py    From AdaMod with Apache License 2.0
def create_optimizer(args, model_params):
    if args.optim == 'sgd':
        return optim.SGD(model_params, args.lr, momentum=args.momentum,
                         weight_decay=args.weight_decay)
    elif args.optim == 'adam':
        return optim.AdamW(model_params, args.lr, betas=(args.beta1, args.beta2),
                           weight_decay=args.weight_decay)
    elif args.optim == 'adamod':
        return AdaMod(model_params, args.lr, betas=(args.beta1, args.beta2),
                      beta3=args.beta3, weight_decay=args.weight_decay) 
Example #7
Source File: factory.py    From incremental_learning.pytorch with MIT License
def get_optimizer(params, optimizer, lr, weight_decay=0.0):
    if optimizer == "adam":
        return optim.Adam(params, lr=lr, weight_decay=weight_decay)
    elif optimizer == "adamw":
        return optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    elif optimizer == "sgd":
        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)
    elif optimizer == "sgd_nesterov":
        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9, nesterov=True)

    raise NotImplementedError 
Example #8
Source File: misc.py    From self-critical.pytorch with MIT License
def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
    # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
    #         torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    optim_func = dict(adam=torch.optim.Adam,
                      adamw=torch.optim.AdamW)[optim_func]
    return NoamOpt(model.d_model, factor, warmup,
            optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 
Example #9
Source File: misc.py    From ImageCaptioning.pytorch with MIT License
def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
    # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
    #         torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    optim_func = dict(adam=torch.optim.Adam,
                      adamw=torch.optim.AdamW)[optim_func]
    return NoamOpt(model.d_model, factor, warmup,
            optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 
Example #10
Source File: optimizer.py    From Auto-PyTorch with Apache License 2.0
def _get_optimizer(self, params, config):
        return optim.AdamW(params=params, lr=config['learning_rate'], weight_decay=config['weight_decay']) 
Example #11
Source File: get_optimizer.py    From BiaffineDependencyParsing with MIT License
def get_optimizer(args, model):
    logger = get_logger(args.log_name)
    args.warmup_steps = math.ceil(args.warmup_prop * args.max_train_steps)
    if args.optimizer == 'adamw-bertology':
        if args.different_lr:
            optimizer_grouped_parameters = _get_bertology_different_lr_grouped_parameters(args, model)
        else:
            optimizer_grouped_parameters = _get_bertology_optimizer_grouped_parameters(args, model)
        optimizer = huggingfaceOptim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
                                           betas=(args.beta1, args.beta2))
        scheduler = huggingfaceOptim.WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps,
                                                          t_total=args.max_train_steps)
        if args.local_rank in [-1, 0]:
            logger.info('Use Huggingface\'s AdamW Optimizer')
    elif args.optimizer == 'adamw-torch':
        try:
            from torch.optim import AdamW
        except ImportError as e:
            debug_print(f'torch version: {torch.__version__}')
            raise e
        if args.different_lr:
            optimizer_grouped_parameters = _get_bertology_different_lr_grouped_parameters(args, model)
        else:
            optimizer_grouped_parameters = _get_bertology_optimizer_grouped_parameters(args, model)
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
                          betas=(args.beta1, args.beta2))
        scheduler = huggingfaceOptim.WarmupLinearSchedule(optimizer,
                                                          warmup_steps=args.warmup_steps,
                                                          t_total=args.max_train_steps)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
        scheduler = None
    elif args.optimizer == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=args.learning_rate)
        scheduler = None
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=args.betas, eps=args.eps,
                                     weight_decay=args.weight_decay)
        scheduler = None
    elif args.optimizer == 'adamax':
        optimizer = torch.optim.Adamax(model.parameters())  # use default lr
        scheduler = None
    else:
        raise Exception("Unsupported optimizer: {}".format(args.optimizer))
    return optimizer, scheduler 
Example #12
Source File: learner.py    From emmental with MIT License
def _set_optimizer(self, model: EmmentalModel) -> None:
        """Set optimizer for learning process.

        Args:
          model: The model to set up the optimizer.
        """
        optimizer_config = Meta.config["learner_config"]["optimizer_config"]
        opt = optimizer_config["optimizer"]

        # If Meta.config["learner_config"]["optimizer_config"]["parameters"] is None,
        # create a parameter group with all parameters in the model, else load user
        # specified parameter groups.
        if optimizer_config["parameters"] is None:
            parameters = filter(lambda p: p.requires_grad, model.parameters())
        else:
            parameters = optimizer_config["parameters"](model)

        optim_dict = {
            # PyTorch optimizer
            "asgd": optim.ASGD,  # type: ignore
            "adadelta": optim.Adadelta,  # type: ignore
            "adagrad": optim.Adagrad,  # type: ignore
            "adam": optim.Adam,  # type: ignore
            "adamw": optim.AdamW,  # type: ignore
            "adamax": optim.Adamax,  # type: ignore
            "lbfgs": optim.LBFGS,  # type: ignore
            "rms_prop": optim.RMSprop,  # type: ignore
            "r_prop": optim.Rprop,  # type: ignore
            "sgd": optim.SGD,  # type: ignore
            "sparse_adam": optim.SparseAdam,  # type: ignore
            # Customize optimizer
            "bert_adam": BertAdam,
        }

        if opt in ["lbfgs", "r_prop", "sparse_adam"]:
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                **optimizer_config[f"{opt}_config"],
            )
        elif opt in optim_dict.keys():
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config[f"{opt}_config"],
            )
        elif isinstance(opt, optim.Optimizer):  # type: ignore
            optimizer = opt(parameters)
        else:
            raise ValueError(f"Unrecognized optimizer option '{opt}'")

        self.optimizer = optimizer

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Using optimizer {self.optimizer}") 
Example #13
Source File: mcts_base.py    From doom-net-pytorch with MIT License
def train_model(self, model, args, epoch_num=10):
        dataset = MCTSDataset(args)
        training_data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=args.batch_size, shuffle=True)

        model.train()
        optimizer = optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-4, amsgrad=True)

        mean_value_loss = 0
        mean_policy_loss = 0
        mean_accuracy = 0
        updates = 0

        batch_time = time.time()
        for epoch in range(epoch_num):
            for batch, (state, target_action, target_value) in enumerate(training_data_loader):
                state, target_action, target_value = state.to(device), target_action.to(device), target_value.to(device)

                optimizer.zero_grad()
                value, log_action = model(state)
                value_loss = F.mse_loss(value, target_value[:, None])
                policy_loss = F.nll_loss(log_action, target_action)
                loss = value_loss + policy_loss

                loss.backward()
                optimizer.step()

                grads = []
                weights = []
                for p in model.parameters():
                    if p.grad is not None:
                        grads.append(p.grad.data.view(-1))
                        weights.append(p.data.view(-1))
                grads = torch.cat(grads, 0)
                weights = torch.cat(weights, 0)
                grads_norm = grads.norm()
                weights_norm = weights.norm()

                # NaN guard: this comparison fails only when grads_norm is NaN
                assert grads_norm == grads_norm

                _, pred_action = log_action.max(1)
                accuracy = (pred_action == target_action.data).float().mean()

                if epoch == epoch_num - 1:
                    mean_value_loss += value_loss.item()
                    mean_policy_loss += policy_loss.item()
                    mean_accuracy += accuracy
                    updates += 1

        mean_value_loss /= updates
        mean_policy_loss /= updates
        mean_accuracy /= updates

        print(
            "value_loss = {:f} policy_loss = {:f} accuracy = {:f}, train_time = {:.3f}".format(mean_value_loss,
                                                                                               mean_policy_loss,
                                                                                               mean_accuracy,
                                                                                               time.time() - batch_time))

        torch.save(model.state_dict(), args.checkpoint_file)
        torch.save(optimizer.state_dict(), args.checkpoint_file + '_optimizer.pth') 
Example #14
Source File: imitation.py    From doom-net-pytorch with MIT License
def train(args):

    train_set = DoomDataset(args.h5_path)
    np.save('action_set', train_set.action_sets)
    training_data_loader = DataLoader(dataset=train_set, num_workers=2, batch_size=100, shuffle=True)

    model = BaseModel(train_set.input_shape[0], len(train_set.action_sets), 3, args.frame_num).to(device)

    if args.load is not None and os.path.isfile(args.load):
        print("loading model parameters {}".format(args.load))
        source_model = torch.load(args.load)
        model.load_state_dict(source_model.state_dict())
        del source_model

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    for epoch in range(1500000):
        running_loss = 0
        running_accuracy = 0
        batch_time = time.time()
        for batch, (screens, variables, labels) in enumerate(training_data_loader):
            screens, variables, labels = screens.to(device), variables.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(screens, variables)[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            _, pred = outputs.max(1)
            accuracy = (pred == labels).float().mean()
            running_accuracy += accuracy

            batches_per_print = 10
            if batch % batches_per_print == batches_per_print-1:  # print every batches_per_print mini-batches
                print(
                    '[{:d}, {:5d}] loss: {:.3f}, accuracy: {:.3f}, time: {:.6f}'.format(
                    epoch + 1, batch + 1, running_loss/batches_per_print, running_accuracy/batches_per_print, (time.time()-batch_time)/batches_per_print
                    )
                )
                running_loss = 0
                running_accuracy = 0
                batch_time = time.time()

        if epoch % args.checkpoint_rate == args.checkpoint_rate - 1:
            torch.save(model, args.checkpoint_file) 
Example #15
Source File: imitation_frames.py    From doom-net-pytorch with MIT License
def train(args):

    data_file = h5py.File(args.h5_path, 'r')
    screens = data_file['screens']
    variables = data_file['variables']
    labels = data_file['action_labels']
    print('Dataset size =', len(screens))
    action_sets = data_file['action_sets'][:]
    episodes = data_file['episodes'][:]
    input_shape = screens[0].shape
    train_generator = data_generator(args, screens, variables, labels, episodes, args.skiprate)

    np.save('action_set', action_sets)

    model = BaseModel(input_shape[0]*args.frame_num, len(action_sets), variables.shape[1], args.frame_num).to(device)

    if args.load is not None and os.path.isfile(args.load):
        print("loading model parameters {}".format(args.load))
        source_model = torch.load(args.load)
        model.load_state_dict(source_model.state_dict())
        del source_model

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)
    optimizer.zero_grad()
    running_loss = 0
    running_accuracy = 0
    batch_time = time.time()

    for batch, (screens, variables, labels, terminals) in enumerate(train_generator):
        labels = labels.to(device)
        outputs, _ = model(*model.transform_input(screens, variables))
        loss = criterion(outputs, labels)
        model.set_terminal(terminals)

        running_loss += loss.item()
        _, pred = outputs.max(1)
        accuracy = (pred == labels).float().mean()
        running_accuracy += accuracy

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % args.episode_length == args.episode_length - 1:
            running_loss /= args.episode_length
            running_accuracy /= args.episode_length

            print(
                '[{:d}] loss: {:.3f}, accuracy: {:.3f}, time: {:.6f}'.format(
                    batch + 1, running_loss, running_accuracy, time.time()-batch_time
                )
            )
            running_loss = 0
            running_accuracy = 0
            batch_time = time.time()

        if batch % args.checkpoint_rate == args.checkpoint_rate - 1:
            torch.save(model, args.checkpoint_file)