import math
import os
import re
import time

import torch
from sacred import observers
from torch import nn
from torch.utils import checkpoint

import utils
from nde import transforms
from utils import NoDataRootError


class NamingObserver(observers.RunObserver):
    """Sacred run observer that derives a run name from the run config.

    The name is built from ``config['dataset']`` and, when non-empty,
    ``config['run_descr']``. If directories carrying that name (optionally
    followed by ``-<number>``) already exist in ``basedir``, a fresh numeric
    suffix is appended.
    """

    def __init__(self, basedir, priority):
        """Store the directory to scan and the observer priority.

        Parameters
        ----------
        basedir : str
            Directory whose sub-directories are inspected for existing runs.
        priority : int
            Stored as ``self.priority``; presumably read by sacred to order
            observers -- confirm against the sacred documentation.
        """
        self.basedir = basedir
        self.priority = priority

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Return a run name that does not clash with existing run directories.

        Only ``config`` is used; the remaining arguments are accepted to match
        the observer callback signature and are ignored.

        Returns
        -------
        str
            ``prefix`` if no matching directory exists in ``basedir``,
            otherwise ``prefix-<max existing number + 1>``. A directory named
            exactly ``prefix`` counts as number 0.
        """
        prefix = config['dataset']
        if config['run_descr']:
            prefix += '-' + config['run_descr']

        # Escape the prefix so that regex metacharacters in the dataset name
        # or run description are matched literally; the suffix part is a raw
        # string so that ``\d`` is not treated as a string escape.
        pattern = re.compile(re.escape(prefix) + r'(-\d+)?')

        def existing_run_nrs():
            """Yield the run numbers of all matching directories in basedir."""
            for entry in os.listdir(self.basedir):
                if not os.path.isdir(os.path.join(self.basedir, entry)):
                    continue
                match = pattern.fullmatch(entry)
                if match:
                    num_str = match.group(1)
                    # group(1) includes the leading '-'; strip it before
                    # converting. No suffix at all means run number 0.
                    yield int(num_str[1:] if num_str else 0)

        max_nr = max(existing_run_nrs(), default=None)
        if max_nr is None:
            return prefix
        return prefix + '-{}'.format(max_nr + 1)


def imshow(image, ax):
    """Draw ``image`` on the axes ``ax`` without spines, ticks or tick labels.

    Parameters
    ----------
    image : torch.Tensor
        3-dimensional tensor; it is permuted with ``(1, 2, 0)`` before
        plotting (presumably channel-first -> channel-last; confirm with
        callers). If the last dimension after permuting has size 1, the image
        is drawn inverted (``1 - image``) with the ``'Greys'`` colormap.
    ax : matplotlib axes-like object
        Target on which ``imshow`` and the styling calls are made.
    """
    image = utils.tensor2numpy(image.permute(1, 2, 0))
    if image.shape[-1] == 1:
        ax.imshow(1 - image[..., 0], cmap='Greys')
    else:
        ax.imshow(image)

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')


def get_dataset_root():
    """Return the value of the ``DATASET_ROOT`` environment variable.

    Raises
    ------
    NoDataRootError
        If the environment variable is not set.
    """
    env_var = 'DATASET_ROOT'
    try:
        return os.environ[env_var]
    except KeyError:
        raise NoDataRootError(
            "Environment variable {} doesn't exist.".format(env_var))


def eval_log_density(log_prob_fn, data_loader, num_batches=None):
    """Average ``log_prob_fn`` over batches of ``data_loader``.

    The result is the mean of the per-batch means, so batches of different
    sizes are weighted equally. Gradients are disabled during evaluation.

    Parameters
    ----------
    log_prob_fn : callable
        Called with one batch; its result is reduced with ``torch.mean``.
    data_loader : iterable
        Yields batches. A batch that is a ``list`` is treated as
        ``[inputs, labels, ...]`` and only its first element is used.
    num_batches : int, optional
        Stop after this many batches. ``None`` evaluates every batch.

    Returns
    -------
    torch.Tensor
        Sum of per-batch means divided by the number of batches processed.
    """
    with torch.no_grad():
        total_ld = 0
        batch_counter = 0
        for batch in data_loader:
            if isinstance(batch, list):  # If labelled dataset, ignore labels
                batch = batch[0]
            log_prob = log_prob_fn(batch)
            total_ld += torch.mean(log_prob)
            batch_counter += 1
            if (num_batches is not None) and batch_counter == num_batches:
                break
        return total_ld / batch_counter


def eval_log_density_2(log_prob_fn, data_loader, c, h, w, num_batches=None):
    """Evaluate ``log_prob_fn`` per example and summarise it per dimension.

    All per-batch outputs of ``log_prob_fn`` are concatenated, rescaled with
    ``nats_to_bits_per_dim`` and then reduced. Gradients are disabled.

    Parameters
    ----------
    log_prob_fn : callable
        Called with one batch; must return a tensor that ``torch.cat`` can
        concatenate with the outputs for the other batches.
    data_loader : iterable
        Yields batches. A batch that is a ``list`` is treated as
        ``[inputs, labels, ...]`` and only its first element is used.
    c, h, w : int
        Factors passed to ``nats_to_bits_per_dim`` (presumably channels,
        height and width of one example -- confirm with callers).
    num_batches : int, optional
        Stop after this many batches. ``None`` evaluates every batch.

    Returns
    -------
    tuple of torch.Tensor
        ``(mean, 2 * std / N)`` of the rescaled values, where ``N`` is the
        size of the first dimension of the concatenated tensor.
    """
    with torch.no_grad():
        total_ld = []
        batch_counter = 0
        for batch in data_loader:
            if isinstance(batch, list):  # If labelled dataset, ignore labels
                batch = batch[0]
            log_prob = log_prob_fn(batch)
            total_ld.append(log_prob)
            batch_counter += 1
            if (num_batches is not None) and batch_counter == num_batches:
                break
        total_ld = torch.cat(total_ld)
        total_ld = nats_to_bits_per_dim(total_ld, c, h, w)
        # NOTE(review): this divides by N, not sqrt(N). If the second value is
        # meant to be two standard errors of the mean, it should probably be
        # ``2 * std / sqrt(N)`` -- left unchanged pending confirmation, since
        # altering it would change reported numbers.
        return total_ld.mean(), 2 * total_ld.std() / total_ld.shape[0]


class CheckpointWrapper(transforms.Transform):
    """Transform whose forward pass goes through ``torch.utils.checkpoint``.

    The inverse pass calls the wrapped transform's ``inverse`` directly.
    """

    def __init__(self, transform):
        """Wrap ``transform``."""
        super().__init__()
        self.transform = transform

    def forward(self, inputs):
        """Apply the wrapped transform via ``checkpoint.checkpoint``."""
        return checkpoint.checkpoint(self.transform, inputs)

    def inverse(self, inputs):
        """Apply the wrapped transform's inverse without checkpointing."""
        return self.transform.inverse(inputs)


class Conv2dSameSize(nn.Conv2d):
    """``nn.Conv2d`` with padding fixed to ``kernel_size // 2``."""

    def __init__(self, in_channels, out_channels, kernel_size):
        """Build the convolution; ``kernel_size`` must be an ``int``."""
        # Padding that would keep the spatial dims the same
        same_padding = kernel_size // 2
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=same_padding)


def descendants_of_type(transform, type):
    """Collect all transforms of the given type nested inside ``transform``.

    Parameters
    ----------
    transform : object
        Root of the search. If it is itself an instance of ``type`` it is
        returned alone and not searched further.
    type : type or tuple of types
        Class(es) to look for. (The parameter name shadows the builtin; it is
        kept for backward compatibility with keyword callers.)

    Returns
    -------
    list
        Matching transforms, in the order they are found. Only
        ``CompositeTransform`` and ``MultiscaleCompositeTransform`` containers
        are descended into, via their ``_transforms`` attribute.
    """
    if isinstance(transform, type):
        return [transform]

    containers = (transforms.CompositeTransform,
                  transforms.MultiscaleCompositeTransform)
    if isinstance(transform, containers):
        found = []
        for child in transform._transforms:
            found.extend(descendants_of_type(child, type))
        return found

    return []


class Timer:
    """Context manager that measures wall-clock time with ``time.time()``.

    After the ``with`` block, ``start``, ``end`` and ``interval`` (seconds)
    are available as attributes.
    """

    def __init__(self, print=False):
        """If ``print`` is true, report the elapsed time on exit."""
        # The parameter shadows the builtin only inside this method; it is
        # kept for backward compatibility with keyword callers.
        self.print = print

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval = self.end - self.start
        if self.print:
            print('Operation took {:.03f} sec.'.format(self.interval))


# From https://github.com/tqdm/tqdm/blob/master/tqdm/_tqdm.py
def format_interval(t):
    """
    Formats a number of seconds as a clock time, [H:]MM:SS

    Parameters
    ----------
    t  : int
        Number of seconds.

    Returns
    -------
    out  : str
        [H:]MM:SS
    """
    mins, s = divmod(int(t), 60)
    h, m = divmod(mins, 60)
    if h:
        return '{0:d}:{1:02d}:{2:02d}'.format(h, m, s)
    else:
        return '{0:02d}:{1:02d}'.format(m, s)


def progress_string(elapsed_time, step, num_steps):
    """Format elapsed time, estimated remaining time and iteration rate.

    Parameters
    ----------
    elapsed_time : float
        Seconds elapsed so far.
    step : int
        Number of steps completed.
    num_steps : int
        Total number of steps.

    Returns
    -------
    str
        ``'<elapsed><<remaining>, <rate>it/s'``. The remaining time is shown
        as ``'...'`` whenever the rate is not positive, which includes the
        case ``elapsed_time == 0``.
    """
    try:
        rate = step / elapsed_time
    except ZeroDivisionError:
        # No time has elapsed yet, so no meaningful rate can be computed.
        rate = 0.0

    if rate > 0:
        remaining_time = format_interval((num_steps - step) / rate)
    else:
        remaining_time = '...'
    elapsed_time = format_interval(elapsed_time)
    return '{}<{}, {:.2f}it/s'.format(elapsed_time, remaining_time, rate)


class LogProbWrapper(nn.Module):
    """``nn.Module`` whose ``forward`` returns ``flow.log_prob(...)``."""

    def __init__(self, flow):
        """Wrap ``flow``, which must provide ``log_prob(inputs, context)``."""
        super().__init__()
        self.flow = flow

    def forward(self, inputs, context=None):
        """Return ``self.flow.log_prob(inputs, context)``."""
        return self.flow.log_prob(inputs, context)


def nats_to_bits_per_dim(nats, c, h, w):
    """Divide ``nats`` by ``ln(2) * c * h * w``.

    This converts natural-log units to base-2 units and normalises by the
    product ``c * h * w``. No sign change is applied.
    """
    return nats / (math.log(2) * c * h * w)


# https://stackoverflow.com/questions/431684/how-do-i-change-directory-cd-in-python/13197763#13197763
class cd:
    """Context manager for changing the current working directory"""

    def __init__(self, newPath):
        # Expand a leading '~' so that home-relative paths work.
        self.newPath = os.path.expanduser(newPath)

    def __enter__(self):
        # Remember where we were so that __exit__ can restore it.
        self.savedPath = os.getcwd()
        os.chdir(self.newPath)

    def __exit__(self, etype, value, traceback):
        os.chdir(self.savedPath)