Python lib.utils.get_logger() Examples
The following is 1 code example of lib.utils.get_logger().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module lib.utils, or try the search function.
Example #1
Source File: ToyExperiments.py From UMNN with BSD 3-Clause "New" or "Revised" License | 4 votes |
def train_toy(toy, load=True, nb_steps=20, nb_flow=1, folder=""):
    """Train a UMNN-MAF flow on a 2-D toy dataset, logging and checkpointing.

    Args:
        toy: Name of the toy dataset. It is passed to
            ``toy_data.inf_train_gen`` and is also the sub-directory of
            ``folder`` that holds the logs and checkpoints.
        load: If True, restore model and Adam state from
            ``<folder>/<toy>/model.pt`` and ``<folder>/<toy>/ADAM.pt``
            before training.
        nb_steps: Forwarded to ``UMNNMAFFlow`` (presumably the number of
            integration steps -- confirm against the model definition).
        nb_flow: Forwarded to ``UMNNMAFFlow`` (presumably the number of
            stacked flow steps -- confirm).
        folder: Root output directory; joined with ``toy`` via
            ``os.path.join`` for logs and checkpoints alike.

    Runs 10000 epochs. Every 100 epochs it calls ``summary_plots`` and
    overwrites both checkpoints. NOTE(review): ``<folder>/<toy>`` is not
    created here; it is assumed to exist already.
    """
    device = "cpu"
    # Build every output path the same way so logs and checkpoints always
    # land in the same directory, with or without a trailing slash on `folder`.
    run_dir = os.path.join(folder, toy)
    model_path = os.path.join(run_dir, 'model.pt')
    opt_path = os.path.join(run_dir, 'ADAM.pt')
    logger = utils.get_logger(logpath=os.path.join(run_dir, 'logs'), filepath=os.path.abspath(__file__))

    logger.info("Creating model...")
    model = UMNNMAFFlow(nb_flow=nb_flow, nb_in=2,
                        hidden_derivative=[100, 100, 100, 100],
                        hidden_embedding=[100, 100, 100, 100],
                        embedding_s=10, nb_steps=nb_steps, device=device).to(device)
    logger.info("Model created.")
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)

    if load:
        logger.info("Loading model...")
        # map_location lets checkpoints written on another device
        # (e.g. a GPU run) be restored onto this device.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.train()
        opt.load_state_dict(torch.load(opt_path, map_location=device))
        logger.info("Model loaded.")

    nb_samp = 100
    batch_size = 100
    # Divisor used to average the per-batch losses into an epoch loss.
    nb_batches = nb_samp / batch_size
    x_test = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)
    x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)

    for epoch in range(10000):
        ll_tot = 0
        start = timer()
        for _ in range(0, nb_samp, batch_size):
            # Fresh samples each batch: the toy generator is sampled on demand.
            cur_x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
            ll, _ = model.compute_ll(cur_x)
            loss = -ll.mean()  # negative mean log-likelihood
            ll_tot += loss.detach() / nb_batches
            opt.zero_grad()
            loss.backward()
            opt.step()
        end = timer()

        # Evaluation only: no gradients are needed for the held-out loss, so
        # skip building (and retaining) an autograd graph over x_test.
        with torch.no_grad():
            ll_test, _ = model.compute_ll(x_test)
            ll_test = -ll_test.mean()

        logger.info("epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - Elapsed time per epoch {:4f} (seconds)".
                    format(epoch, ll_tot.item(), ll_test.item(), end-start))

        if (epoch % 100) == 0:
            summary_plots(x, x_test, folder, epoch, model, ll_tot, ll_test)
            torch.save(model.state_dict(), model_path)
            torch.save(opt.state_dict(), opt_path)