# -*- coding: utf-8 -*-
#

# Imports
import torch
import numpy as np
from .error_measures import nrmse, generalized_squared_cosine
from scipy.interpolate import interp1d
import numpy.linalg as lin


# Compute correlation matrix
def compute_correlation_matrix(states):
    """
    Compute correlation matrix
    :param states: 2-D tensor; the product states^T . states is divided by the size of its first dimension
    :return: The square matrix states^T . states / states.size(0)
    """
    return states.t().mm(states) / float(states.size(0))
# end compute_correlation_matrix


# Align pattern
def align_pattern(interpolation_rate, truth_pattern, generated_pattern):
    """
    Align a generated pattern on a reference pattern by searching for the phase shift
    (on quadratically interpolated versions of both signals) which minimises their L2 distance.
    :param interpolation_rate: Number of interpolated points per original time step (used as an index step,
    so it is expected to be a positive integer)
    :param truth_pattern: Reference pattern (tensor, flattened to 1-D)
    :param generated_pattern: Generated pattern (tensor, flattened to 1-D); its interpolated version must be
    longer than the interpolated reference, otherwise there is no phase shift to test
    :return: Best phase shift in the interpolated signal, the corresponding position in the original signal,
    and the aligned generated pattern as a (length, 1) tensor
    """
    # Length
    truth_length = truth_pattern.size(0)
    generated_length = generated_pattern.size(0)

    # Remove useless dimension
    truth_pattern = truth_pattern.view(-1)
    generated_pattern = generated_pattern.view(-1)

    # Quadratic interpolation functions
    truth_pattern_func = interp1d(np.arange(truth_length), truth_pattern.numpy(), kind='quadratic')
    generated_pattern_func = interp1d(np.arange(generated_length), generated_pattern.numpy(), kind='quadratic')

    # Get interpolated patterns (sampled up to, but excluding, the last original time step)
    truth_pattern_int = truth_pattern_func(np.arange(0, truth_length - 1.0, 1.0 / interpolation_rate))
    generated_pattern_int = generated_pattern_func(np.arange(0, generated_length - 1.0, 1.0 / interpolation_rate))

    # Generated interpolated pattern length
    L = generated_pattern_int.shape[0]

    # Truth interpolated pattern length
    M = truth_pattern_int.shape[0]

    # Save L2 distance for each phase shift
    phase_matches = np.zeros(L - M)

    # For each phase shift
    for phase_shift in range(L - M):
        phase_matches[phase_shift] = lin.norm(
            truth_pattern_int - generated_pattern_int[phase_shift:phase_shift + M]
        )
    # end for

    # Best match (smallest distance)
    max_ind = int(np.argmin(phase_matches))

    # Get the position in the original signal
    # (explicit float division so that ceil is meaningful whatever the division semantics)
    coarse_max_ind = int(np.ceil(max_ind / float(interpolation_rate)))

    # Get the generated output matching the original signal
    generated_aligned = generated_pattern_int[
        np.arange(max_ind, max_ind + interpolation_rate * truth_length, interpolation_rate)
    ]

    return max_ind, coarse_max_ind, torch.from_numpy(generated_aligned).view(-1, 1)
# end align_pattern


# Find phase shift
def find_phase_shift(p, y, interpolation_rate, error_measure=nrmse):
    """
    Find the phase shift which best aligns y on p (linear interpolation, L2 distance),
    and measure the error between p and the aligned signal.
    :param p: Reference pattern (tensor, flattened to 1-D)
    :param y: Signal to align (tensor, flattened to 1-D); its interpolated version must be longer than
    the interpolated reference, otherwise there is no phase shift to test
    :param interpolation_rate: Number of interpolated points per original time step (used as an index step,
    so it is expected to be a positive integer)
    :param error_measure: Callable applied to (aligned signal, reference), both reshaped to (1, length)
    :return: Flattened p, aligned y, phase shift expressed in original time steps (NumPy float),
    and the error after alignment
    """
    # Size
    p_length = p.size(0)
    y_length = y.size(0)

    # 1D
    p = p.view(-1)
    y = y.view(-1)

    # Interpolate p and y
    p_int = torch.from_numpy(
        np.interp(np.arange(0, p_length, 1.0 / interpolation_rate), np.arange(p_length), p.numpy())
    )
    y_int = torch.from_numpy(
        np.interp(np.arange(0, y_length, 1.0 / interpolation_rate), np.arange(y_length), y.numpy())
    )

    # New shape
    L = y_int.shape[0]
    M = p_int.shape[0]

    # Find best phase
    phasematches = torch.zeros(L - M)
    for phaseshift in range(L - M):
        phasematches[phaseshift] = torch.norm(p_int - y_int[phaseshift:phaseshift + M], p=2)
    # end for

    # Best phase, as a plain Python int so that NumPy never receives a 0-d tensor
    max_index = int(torch.argmax(-phasematches).item())

    # Matching phase
    y_aligned = y_int[np.arange(max_index, max_index + interpolation_rate * p_length, interpolation_rate)]

    # Original phase
    original_phase = np.ceil(max_index / float(interpolation_rate))

    # Error after alignment
    error_aligned = error_measure(y_aligned.reshape(1, -1), p.reshape(1, -1))

    return p, y_aligned, original_phase, error_aligned
# end find_phase_shift


# Compute similarity matrix
def compute_similarity_matrix(svd_list):
    """
    Compute similarity matrix
    :param svd_list: List of pairs (S, U), passed as such to generalized_squared_cosine
    (presumably singular values and singular vectors; verify against the callers)
    :return: (n_samples, n_samples) tensor where entry (i, j) is generalized_squared_cosine(Si, Ui, Sj, Uj)
    """
    # N samples
    n_samples = len(svd_list)

    # Similarity matrix
    sim_matrix = torch.zeros(n_samples, n_samples)

    # For each combination
    for i, (Sa, Ua) in enumerate(svd_list):
        for j, (Sb, Ub) in enumerate(svd_list):
            sim_matrix[i, j] = generalized_squared_cosine(Sa, Ua, Sb, Ub)
        # end for
    # end for

    return sim_matrix
# end compute_similarity_matrix


# Compute singular values
def compute_singular_values(stats):
    """
    Compute the singular value decomposition of the correlation matrix of stats
    :param stats: 2-D tensor, see compute_correlation_matrix
    :return: The result of torch.svd on the correlation matrix, i.e. the triple (U, S, V)
    """
    # Compute R (correlation matrix)
    R = compute_correlation_matrix(stats)

    # Compute singular values
    return torch.svd(R)
# end compute_singular_values


# Compute spectral radius of a square 2-D tensor
def spectral_radius(m):
    """
    Compute spectral radius of a square 2-D tensor, i.e. the largest modulus among its eigenvalues
    :param m: squared 2D tensor
    :return: Spectral radius as a Python float
    """
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigvals'):
        # Complex eigenvalues: abs() gives their modulus
        moduli = torch.abs(torch.linalg.eigvals(m))
    else:
        # Legacy PyTorch: torch.eig returns an (n, 2) tensor of (real, imaginary) parts,
        # the modulus has to be computed from both columns
        eigenvalues = torch.eig(m)[0]
        moduli = torch.sqrt(eigenvalues[:, 0] ** 2 + eigenvalues[:, 1] ** 2)
    # end if
    return torch.max(moduli).item()
# end spectral_radius


# Compute spectral radius of a square 2-D tensor for stacked-ESN
def deep_spectral_radius(m, leaky_rate):
    """
    Compute spectral radius of a square 2-D tensor for stacked-ESN,
    i.e. the spectral radius of (1 - leaky_rate) * I + leaky_rate * m
    :param m: squared 2D tensor
    :param leaky_rate: Layer's leaky rate
    :return: Spectral radius as a Python float
    """
    # Identity with the same type and device as m, so that both terms can be added safely
    identity = torch.eye(m.size(0), m.size(0), dtype=m.dtype, device=m.device)
    return spectral_radius((1.0 - leaky_rate) * identity + leaky_rate * m)
# end deep_spectral_radius


# Normalize a tensor on a single dimension
def normalize(tensor, dim=1):
    """
    Normalize a tensor on a single dimension
    NOTE(review): not implemented, the function currently does nothing and returns None.
    :param tensor: Tensor to normalize (unused for now)
    :param dim: Dimension to normalize on (unused for now)
    :return: None
    """
    pass
# end normalize


# Average probabilities through time
def average_prob(tensor, dim=0):
    """
    Average probabilities through time
    :param tensor: Tensor to average
    :param dim: Dimension which is averaged out (time dimension)
    :return: Mean of the tensor over the given dimension
    """
    return torch.mean(tensor, dim=dim)
# end average_prob


# Max average through time
def max_average_through_time(tensor, dim=0):
    """
    Max average through time
    :param tensor: Tensor to average then search
    :param dim: Time dimension; the mean is taken over it, then the maximum is searched along the
    same dimension index of the averaged tensor
    :return: Indices of the maxima of the time-averaged tensor
    """
    average = torch.mean(tensor, dim=dim)
    return torch.max(average, dim=dim)[1]
# end max_average_through_time