import numpy as np
from scipy.ndimage import gaussian_filter as sp_gaussian_filter
from scipy.special import logsumexp
import torch
import torch.nn as nn
from torch.optim.optimizer import required
from tqdm import tqdm

from .models import sample_from_logdensity
from .torch_utils import gaussian_filter


def sample_batch_fixations(log_density, fixations_per_image, batch_size, rst=None):
    """Sample `fixations_per_image` fixations for each of `batch_size` images from the given log density."""
    xs, ys = sample_from_logdensity(log_density, fixations_per_image * batch_size, rst=rst)
    ns = np.repeat(np.arange(batch_size, dtype=int), repeats=fixations_per_image)

    return xs, ys, ns


class DistributionSGD(torch.optim.Optimizer):
    """Extension of SGD that constrains the parameters to be nonnegative and to have a fixed sum (e.g., a probability distribution)."""
    def __init__(self, params, lr=required):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        defaults = dict(lr=lr)
        super(DistributionSGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                learning_rate = group['lr']

                # gradient of the sum constraint: every entry contributes equally
                constraint_grad = torch.ones_like(d_p)
                constraint_grad_norm = torch.sum(torch.pow(constraint_grad, 2))
                normed_constraint_grad = constraint_grad / constraint_grad_norm

                # first step: make sure we are not running into negative values
                max_allowed_grad = p / learning_rate
                projected_grad1 = torch.min(d_p, max_allowed_grad)

                # second step: make sure that the gradient does not walk
                # out of the constraint (the update sums to zero)
                projected_grad2 = projected_grad1 - torch.sum(projected_grad1 * constraint_grad) * normed_constraint_grad

                p.add_(projected_grad2, alpha=-group['lr'])

        return loss
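
# Illustrative sketch (not part of the original module): how DistributionSGD behaves on a
# toy problem. The parameter, learning rate and target below are made up for illustration:
# starting from a uniform 4-entry distribution and pulling it towards another distribution,
# the parameter stays nonnegative and keeps summing to one (up to floating point precision).
def _distribution_sgd_example():
    p = nn.Parameter(torch.full((4,), 0.25))  # a valid probability distribution
    optimizer = DistributionSGD([p], lr=0.1)

    target = torch.tensor([0.7, 0.1, 0.1, 0.1])
    for _ in range(100):
        optimizer.zero_grad()
        loss = torch.sum((p - target) ** 2)  # quadratic loss pulling p towards `target`
        loss.backward()
        optimizer.step()

    # the constraints survive the updates in this toy run
    assert torch.all(p >= 0)
    assert torch.isclose(torch.sum(p), torch.tensor(1.0))
    return p.detach()
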
def build_fixation_maps(Ns, Ys, Xs, batch_size, height, width, dtype=torch.float32):
    """Build dense per-image fixation count maps of shape (batch_size, height, width)."""
    indices = torch.stack((Ns, Ys, Xs), dim=0)
    src = torch.ones(indices.shape[1], dtype=dtype, device=indices.device)
    fixation_maps = torch.sparse_coo_tensor(indices, src, size=(batch_size, height, width)).to_dense()

    return fixation_maps


def torch_similarity(saliency_map, empirical_saliency_maps):
    """Compute the SIM score between one saliency map and a batch of empirical saliency maps."""
    normalized_empirical_saliency_maps = empirical_saliency_maps / torch.sum(empirical_saliency_maps, dim=[1, 2], keepdim=True)
    normalized_saliency_map = saliency_map / torch.sum(saliency_map)
    minimums = torch.min(normalized_empirical_saliency_maps, normalized_saliency_map[None, :, :])
    similarities = torch.sum(minimums, dim=[1, 2])

    return similarities


def compute_similarity(saliency_map, ns, ys, xs, batch_size, kernel_size, truncate_gaussian, dtype=torch.float32):
    height, width = saliency_map.shape
    fixation_maps = build_fixation_maps(ns, ys, xs, batch_size, height, width, dtype=dtype)

    # blur the fixation maps into empirical saliency maps
    empirical_saliency_maps = gaussian_filter(
        fixation_maps[:, None, :, :],
        dim=[2, 3],
        sigma=kernel_size,
        truncate=truncate_gaussian,
        padding_mode='constant',
        padding_value=0.0,
    )[:, 0, :, :]

    similarities = torch_similarity(saliency_map, empirical_saliency_maps)

    return similarities
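
# Illustrative sketch (not part of the original module): a tiny sanity check for
# `torch_similarity`. SIM sums the pixelwise minimum of two normalized maps, so a map
# compared against itself scores 1.0 and against a map with disjoint support scores 0.0.
# The maps below are made up for illustration.
def _torch_similarity_example():
    saliency_map = torch.tensor([[1.0, 0.0],
                                 [0.0, 1.0]])
    empirical_maps = torch.stack([
        saliency_map.clone(),                     # identical map -> SIM = 1
        torch.tensor([[0.0, 1.0], [1.0, 0.0]]),   # disjoint support -> SIM = 0
    ])
    return torch_similarity(saliency_map, empirical_maps)  # tensor([1., 0.])
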
class Similarities(nn.Module):
    """Holds the candidate saliency map as a parameter and scores it against batches of sampled fixations."""
    def __init__(self, initial_saliency_map, kernel_size, truncate_gaussian=3, dtype=torch.float32):
        super().__init__()

        self.saliency_map = nn.Parameter(torch.tensor(initial_saliency_map, dtype=dtype), requires_grad=True)
        self.kernel_size = kernel_size
        self.truncate_gaussian = truncate_gaussian
        self.dtype = dtype

    def forward(self, ns, ys, xs, batch_size):
        similarities = compute_similarity(
            self.saliency_map,
            ns, ys, xs,
            batch_size,
            self.kernel_size,
            self.truncate_gaussian,
            dtype=self.dtype,
        )
        return similarities


def _eval_metric(log_density, test_samples, fn, seed=42, fixation_count=120, batch_size=50, verbose=True):
    values = []
    weights = []

    count = 0
    rst = np.random.RandomState(seed=seed)

    with tqdm(total=test_samples, leave=False, disable=not verbose) as t:
        while count < test_samples:
            this_count = min(batch_size, test_samples - count)
            xs, ys, ns = sample_batch_fixations(log_density, fixations_per_image=fixation_count, batch_size=this_count, rst=rst)

            values.append(fn(ns, ys, xs, this_count))
            weights.append(this_count)

            count += this_count
            t.update(this_count)

    weights = np.asarray(weights, dtype=np.float64) / np.sum(weights)
    return np.average(values, weights=weights)


def maximize_expected_sim(log_density, kernel_size,
                          train_samples_per_epoch, val_samples,
                          train_seed=43, val_seed=42,
                          fixation_count=100, batch_size=50,
                          max_batch_size=None,
                          verbose=True, session_config=None,
                          initial_learning_rate=1e-7,
                          backlook=1,
                          min_iter=0,
                          max_iter=1000,
                          truncate_gaussian=3,
                          learning_rate_decay_samples=None,
                          initial_saliency_map=None,
                          learning_rate_decay_scheme=None,
                          learning_rate_decay_ratio=0.333333333,
                          minimum_learning_rate=1e-11):
    """Optimize a saliency map to maximize the expected SIM score under the given fixation log density.

    max_batch_size: maximum possible batch size to be used in validation
    learning_rate_decay_samples: how often to decay the learning rate (using 1/k)
    learning_rate_decay_scheme: how to decay the learning rate:
        - None, "1/k": 1/k scheme
        - "validation_loss": decay if the validation loss has not improved for the last `backlook` steps
    learning_rate_decay_ratio: how much to decay the learning rate if
        `learning_rate_decay_scheme` == 'validation_loss'
    minimum_learning_rate: stop the optimization if the learning rate would drop below this rate
        (only used with the validation loss decay scheme)
    """
    if max_batch_size is None:
        max_batch_size = batch_size

    if learning_rate_decay_scheme is None:
        learning_rate_decay_scheme = '1/k'

    if learning_rate_decay_samples is None:
        learning_rate_decay_samples = train_samples_per_epoch

    log_density_sum = logsumexp(log_density)
    if not -0.001 < log_density_sum < 0.001:
        raise ValueError("Log density not normalized! LogSumExp={}".format(log_density_sum))

    if initial_saliency_map is None:
        initial_value = sp_gaussian_filter(np.exp(log_density), kernel_size, mode='constant')
    else:
        initial_value = initial_saliency_map

    if initial_value.min() < 0:
        initial_value -= initial_value.min()
    initial_value /= initial_value.sum()

    dtype = torch.float32

    model = Similarities(
        initial_saliency_map=initial_value,
        kernel_size=kernel_size,
        truncate_gaussian=truncate_gaussian,
        dtype=dtype,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device", device)
    model.to(device)

    optimizer = DistributionSGD(model.parameters(), lr=initial_learning_rate)

    if learning_rate_decay_scheme == '1/k':
        def lr_lambda(epoch):
            return 1.0 / max(epoch, 1)

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer=optimizer,
            lr_lambda=lr_lambda,
        )
    elif learning_rate_decay_scheme == 'validation_loss':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer,
            gamma=learning_rate_decay_ratio,
        )
    else:
        raise ValueError(learning_rate_decay_scheme)

    height, width = log_density.shape

    def _val_loss(ns, ys, xs, batch_size):
        model.eval()
        Ns = torch.tensor(ns).to(device)
        Ys = torch.tensor(ys).to(device)
        Xs = torch.tensor(xs).to(device)
        return -torch.mean(model(Ns, Ys, Xs, batch_size)).detach().cpu().numpy()

    def val_loss():
        return _eval_metric(log_density, val_samples, _val_loss,
                            seed=val_seed, fixation_count=fixation_count,
                            batch_size=max_batch_size, verbose=False)

    total_samples = 0
    decay_step = 0

    val_scores = [val_loss()]
    learning_rate_relevant_scores = list(val_scores)

    train_rst = np.random.RandomState(seed=train_seed)

    with tqdm(disable=not verbose) as outer_t:
        def general_termination_condition():
            return len(val_scores) - 1 >= max_iter

        def termination_1overk():
            return not (np.argmin(val_scores) >= len(val_scores) - backlook)

        def termination_validation():
            return optimizer.state_dict()['param_groups'][0]['lr'] < minimum_learning_rate

        def termination_condition():
            if len(val_scores) < min_iter:
                return False

            cond = general_termination_condition()
            if learning_rate_decay_scheme == '1/k':
                cond = cond or termination_1overk()
            elif learning_rate_decay_scheme == 'validation_loss':
                cond = cond or termination_validation()

            return cond

        while not termination_condition():
            count = 0
            with tqdm(total=train_samples_per_epoch, leave=False, disable=True) as t:
                while count < train_samples_per_epoch:
                    model.train()
                    optimizer.zero_grad()

                    this_count = min(batch_size, train_samples_per_epoch - count)
                    xs, ys, ns = sample_batch_fixations(log_density, fixations_per_image=fixation_count, batch_size=this_count, rst=train_rst)

                    Ns = torch.tensor(ns).to(device)
                    Ys = torch.tensor(ys).to(device)
                    Xs = torch.tensor(xs).to(device)

                    # the current batch contains `this_count` images, which can be
                    # smaller than `batch_size` for the last batch of an epoch
                    loss = -torch.mean(model(Ns, Ys, Xs, this_count))

                    loss.backward()
                    optimizer.step()

                    with torch.no_grad():
                        # project back onto the simplex: clip negative entries, then renormalize
                        if torch.sum(model.saliency_map < 0):
                            model.saliency_map.mul_(model.saliency_map >= 0)
                        model.saliency_map.div_(torch.sum(model.saliency_map))

                    count += this_count
                    total_samples += this_count

                    if learning_rate_decay_scheme == '1/k':
                        if total_samples >= (decay_step + 1) * learning_rate_decay_samples:
                            decay_step += 1
                            scheduler.step()

                    t.update(this_count)

            val_scores.append(val_loss())
            learning_rate_relevant_scores.append(val_scores[-1])

            if learning_rate_decay_scheme == 'validation_loss' and np.argmin(learning_rate_relevant_scores) < len(learning_rate_relevant_scores) - backlook:
                scheduler.step()
                learning_rate_relevant_scores = [learning_rate_relevant_scores[-1]]
            score1, score2 = val_scores[-2:]
            last_min = len(val_scores) - np.argmin(val_scores) - 1

            outer_t.set_description('{:.05f}->{:.05f}, diff {:.02e}, best val {} steps ago, lr {:.02e}'.format(
                val_scores[0], score2, score2 - score1, last_min,
                optimizer.state_dict()['param_groups'][0]['lr'],
            ))
            outer_t.update(1)

    return model.saliency_map.detach().cpu().numpy(), val_scores[-1]
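
# Illustrative sketch (not part of the original module): one way `maximize_expected_sim`
# might be called. The synthetic log density and the hyperparameters below are made up for
# illustration and are not recommended settings.
def _maximize_expected_sim_example():
    # build a small normalized fixation density and take its log
    density = sp_gaussian_filter(np.random.RandomState(0).rand(40, 60), sigma=5, mode='constant')
    density /= density.sum()
    log_density = np.log(density)

    saliency_map, final_val_loss = maximize_expected_sim(
        log_density,
        kernel_size=5,                  # sigma of the Gaussian used to blur fixations
        train_samples_per_epoch=100,
        val_samples=100,
        max_iter=5,
        verbose=False,
    )
    return saliency_map, final_val_loss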