import torch.optim as optim

from .bgd_optimizer import BGD


def bgd(model, **kwargs):
    """Build a BGD optimizer over all model parameters."""
    logger = kwargs.get("logger", None)
    assert logger is not None
    bgd_params = {
        "mean_eta": kwargs.get("mean_eta", 1),
        "std_init": kwargs.get("std_init", 0.02),
        "mc_iters": kwargs.get("mc_iters", 10),
    }
    logger.info("BGD params: " + str(bgd_params))
    all_params = [{"params": params} for _, params in model.named_parameters()]
    return BGD(all_params, **bgd_params)


def sgd(model, **kwargs):
    """Build a torch.optim.SGD optimizer with per-parameter names and initial_lr."""
    logger = kwargs.get("logger", None)
    assert logger is not None
    sgd_params = {
        "momentum": kwargs.get("momentum", 0.9),
        "lr": kwargs.get("lr", 0.1),
        "weight_decay": kwargs.get("weight_decay", 5e-4),
    }
    logger.info("SGD params: " + str(sgd_params))
    all_params = [{"params": params, "name": name, "initial_lr": sgd_params["lr"]}
                  for name, params in model.named_parameters()]
    return optim.SGD(all_params, **sgd_params)


def adam(model, **kwargs):
    """Build a torch.optim.Adam optimizer with per-parameter names and initial_lr."""
    logger = kwargs.get("logger", None)
    assert logger is not None
    adam_params = {
        "eps": kwargs.get("eps", 1e-08),
        "lr": kwargs.get("lr", 0.001),
        "betas": kwargs.get("betas", (0.9, 0.999)),
        "weight_decay": kwargs.get("weight_decay", 0),
    }
    logger.info("ADAM params: " + str(adam_params))
    all_params = [{"params": params, "name": name, "initial_lr": adam_params["lr"]}
                  for name, params in model.named_parameters()]
    return optim.Adam(all_params, **adam_params)


def adagrad(model, **kwargs):
    """Build a torch.optim.Adagrad optimizer with per-parameter names and initial_lr."""
    logger = kwargs.get("logger", None)
    assert logger is not None
    adagrad_params = {
        "lr": kwargs.get("lr", 0.01),
        "weight_decay": kwargs.get("weight_decay", 0),
    }
    logger.info("Adagrad params: " + str(adagrad_params))
    all_params = [{"params": params, "name": name, "initial_lr": adagrad_params["lr"]}
                  for name, params in model.named_parameters()]
    return optim.Adagrad(all_params, **adagrad_params)
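
# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how a training script might call one of the factory functions above.
# `SomeModel` and the logging setup are hypothetical placeholders; the factory
# names (`bgd`, `sgd`, `adam`, `adagrad`) and their keyword arguments come from
# this module.
#
#   import logging
#
#   logging.basicConfig(level=logging.INFO)
#   logger = logging.getLogger(__name__)
#
#   model = SomeModel()                                # hypothetical model
#   optimizer = sgd(model, logger=logger, lr=0.05)     # or adam(...) / adagrad(...)
#   # then use optimizer.zero_grad() / optimizer.step() in the training loop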