Python lib.utils.get_logger() Examples

The following is one code example of `lib.utils.get_logger()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `lib.utils`, or try the search function.
Example #1
Source File: ToyExperiments.py · From UMNN, under the BSD 3-Clause "New" or "Revised" License · 4 votes
def train_toy(toy, load=True, nb_steps=20, nb_flow=1, folder="", nb_epochs=10000, device="cpu"):
    """Train a UMNNMAFFlow on a 2-D toy dataset, logging every epoch and checkpointing periodically.

    All artifacts live under ``os.path.join(folder, toy)``: the log (``logs``),
    the model weights (``model.pt``) and the Adam optimizer state (``ADAM.pt``).
    Every 100 epochs (starting with epoch 0) ``summary_plots`` is called and
    both checkpoints are overwritten.

    Args:
        toy: Name of the toy distribution, forwarded to ``toy_data.inf_train_gen``;
            also used as the run sub-directory name.
        load: If True, resume model and optimizer state from ``model.pt`` /
            ``ADAM.pt`` in the run directory before training.
        nb_steps: Forwarded to ``UMNNMAFFlow(nb_steps=...)`` (presumably the
            number of integration steps -- confirm against UMNNMAFFlow).
        nb_flow: Forwarded to ``UMNNMAFFlow(nb_flow=...)``.
        folder: Root output directory; a trailing path separator is optional.
        nb_epochs: Number of training epochs (defaults to the previously
            hard-coded 10000).
        device: Torch device for the model, the data and checkpoint loading
            (defaults to the previously hard-coded "cpu").
    """
    # Build every artifact path the same way so the log and the checkpoints
    # always end up in the same directory, with or without a trailing "/".
    run_dir = os.path.join(folder, toy)
    model_path = os.path.join(run_dir, 'model.pt')
    opt_path = os.path.join(run_dir, 'ADAM.pt')
    if run_dir:
        # The log file and torch.save both write into run_dir.
        os.makedirs(run_dir, exist_ok=True)
    logger = utils.get_logger(logpath=os.path.join(run_dir, 'logs'), filepath=os.path.abspath(__file__))

    logger.info("Creating model...")
    model = UMNNMAFFlow(nb_flow=nb_flow, nb_in=2, hidden_derivative=[100, 100, 100, 100], hidden_embedding=[100, 100, 100, 100],
                        embedding_s=10, nb_steps=nb_steps, device=device).to(device)
    logger.info("Model created.")
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)

    if load:
        logger.info("Loading model...")
        # map_location lets checkpoints saved on another device (e.g. a GPU)
        # be restored onto `device`.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.train()
        opt.load_state_dict(torch.load(opt_path, map_location=device))
        logger.info("Model loaded.")

    nb_samp = 100
    batch_size = 100
    # Every inner iteration draws a fresh batch of exactly batch_size points,
    # so the epoch mean must divide by the real batch count,
    # ceil(nb_samp / batch_size), not by nb_samp / batch_size.
    batch_starts = range(0, nb_samp, batch_size)
    nb_batches = len(batch_starts)

    # Fixed sample used for the per-epoch test loss.
    x_test = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)
    # Fixed sample only handed to summary_plots (presumably as a reference for
    # plotting -- confirm against summary_plots).
    x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)

    for epoch in range(nb_epochs):
        ll_tot = 0
        start = timer()
        for _ in batch_starts:
            cur_x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
            ll, _ = model.compute_ll(cur_x)
            # Minimize the negative mean log-likelihood.
            loss = -ll.mean()
            # Detach so the running average does not keep the autograd graph alive.
            ll_tot += loss.detach() / nb_batches
            opt.zero_grad()
            loss.backward()
            opt.step()
        end = timer()
        ll_test, _ = model.compute_ll(x_test)
        ll_test = -ll_test.mean()
        logger.info("epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - Elapsed time per epoch {:4f} (seconds)".
                    format(epoch, ll_tot.item(), ll_test.item(), end-start))

        if (epoch % 100) == 0:
            summary_plots(x, x_test, folder, epoch, model, ll_tot, ll_test)
            torch.save(model.state_dict(), model_path)
            torch.save(opt.state_dict(), opt_path)