from .normalization import normalize
import torch
from torch.distributions import Distribution
from math import sqrt
from typing import Union, Tuple, Iterable

# Square root of the float32 machine epsilon, kept as a module-level constant.
EPS = sqrt(torch.finfo(torch.float32).eps)

def get_ess(w: torch.Tensor, normalized=False):
    """
    Calculates the effective sample size (ESS) along the last axis as
    ``sum(w) ** 2 / sum(w ** 2)``.

    :param w: The log weights if ``normalized`` is False, otherwise the already
        normalized weights
    :param normalized: Whether input is normalized. When False, ``w`` is first passed
        through ``normalize`` (presumably mapping log weights to normalized weights -
        confirm against the ``normalization`` module)
    :return: The effective sample size, reduced over the last axis
    """

    if not normalized:
        w = normalize(w)

    return w.sum(-1) ** 2 / (w ** 2).sum(-1)

def choose(array: torch.Tensor, indices: torch.Tensor):
    """
    Function for choosing on either columns or index.

    :param array: The array to choose on
    :param indices: The indices to choose from `array`. With fewer than two dimensions
        they index the first axis of `array` directly; otherwise row ``i`` of `indices`
        selects entries along the second axis of row ``i`` of `array`
    :return: Returns the chosen elements from `array`
    """

    if indices.dim() < 2:
        return array[indices]

    # Pair every row of `indices` with its own row number so the selection is row-wise
    rows = torch.arange(array.shape[0], device=array.device)[:, None]

    return array[rows, indices]

def loglikelihood(w: torch.Tensor, weights: torch.Tensor = None):
    """
    Calculates the estimated log-likelihood given weights, reducing over the last axis.

    :param w: The log weights, corresponding to likelihood
    :param weights: Optional weights applied to ``exp(w)``. When omitted, the plain
        average of ``exp(w)`` is used; otherwise the weighted sum is used
    :return: The log-likelihood
    """

    maxw, _ = w.max(-1)

    # Shift by the maximum before exponentiating so that large log weights do not
    # overflow; the maximum is added back after taking the logarithm.
    shifted = torch.exp(w - (maxw.unsqueeze(-1) if maxw.dim() > 0 else maxw))

    # ===== Calculate the second term ===== #
    if weights is None:
        temp = shifted.mean(-1).log()
    else:
        temp = (weights * shifted).sum(-1).log()

    return maxw + temp

def concater(*x: Union[Iterable[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """
    Concatenates output by broadcasting all positional tensors against each other and
    stacking them along a new last axis.

    :param x: The tensors to combine, passed as positional arguments
    :type x: tuple[torch.Tensor]|torch.Tensor
    :return: The stacked tensor, with the inputs laid out along the last axis
    """

    # NOTE(review): an earlier `isinstance(x, torch.Tensor)` shortcut was removed here.
    # `x` is always the tuple of positional arguments, so that check could never
    # succeed. A single tensor argument therefore still gains a trailing axis of size
    # one - confirm with callers whether it should be returned unchanged instead.
    return torch.stack(torch.broadcast_tensors(*x), dim=-1)

def construct_diag(x: torch.Tensor):
    """
    Constructs a diagonal matrix based on batched data. Do note that it only considers
    the last axis.

    :param x: The tensor
    :return: ``x`` itself for a 0-dimensional input; ``x`` with a trailing axis appended
        when the last axis holds fewer than two entries; otherwise a tensor with one
        extra trailing axis where the last two axes form a diagonal matrix built from
        the last axis of ``x``
    """

    if x.dim() < 1:
        return x
    elif x.shape[-1] < 2:
        return x.unsqueeze(-1)
    elif x.dim() < 2:
        return torch.diag(x)

    # Repeat the last axis along a new trailing axis and mask with the identity so
    # that only the diagonal entries survive.
    identity = torch.eye(x.size(-1), device=x.device)
    repeated = x.unsqueeze(-1).expand(*x.size(), x.size(-1))

    return repeated * identity

def flatten(*args: Iterable[Iterable]) -> Tuple:
    """
    Flattens an arbitrary nesting of iterables into a single flat tuple. Strings, bytes
    and tensors are treated as atoms and are never iterated over.

    :param args: The iterable you wish to flatten.
    :type args: collections.Iterable
    :return: Flattened iterable
    """

    out = []
    for el in args:
        if isinstance(el, Iterable) and not isinstance(el, (str, bytes, torch.Tensor)):
            # Nested container: flatten its members recursively
            out.extend(flatten(*el))
        else:
            out.append(el)

    return tuple(out)

def unflattify(values: torch.Tensor, shape: torch.Size):
    """
    Unflattifies parameter values.

    :param values: The flattened array of values that are to be unflattified
    :param shape: The shape of the parameter prior
    :return: ``values`` unchanged when ``shape`` is empty or the trailing dimensions of
        ``values`` already equal ``shape``; otherwise ``values`` reshaped to
        ``(values.shape[0], *shape)``
    """

    if len(shape) < 1 or values.shape[1:] == shape:
        return values

    return values.reshape(values.shape[0], *shape)

class TempOverride(object):
    """
    Context manager that swaps an attribute of an object for the duration of a ``with``
    block and puts the previous value back on exit.
    """

    def __init__(self, obj: object, attr: str, new_vals: object):
        """
        Implements a temporary override of attribute of an object.

        :param obj: An object
        :param attr: The attribute to override
        :param new_vals: The new values
        """

        self._obj = obj
        self._attr = attr
        self._new_vals = new_vals
        self._old_vals = None

    def __enter__(self):
        """
        Stores the current attribute value and installs the new one.

        :return: This context manager
        """

        self._old_vals = getattr(self._obj, self._attr)
        setattr(self._obj, self._attr, self._new_vals)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restores the attribute value saved on entry.

        :return: False, so an exception raised inside the block is not suppressed
        """

        setattr(self._obj, self._attr, self._old_vals)

        return False

class Empirical(Distribution):
    """
    Helper class for timeseries without an analytical expression. Wraps a fixed set of
    samples and hands them back unchanged from ``sample``.
    """

    def __init__(self, samples: torch.Tensor):
        """
        Initializes the distribution from a set of samples.

        :param samples: The sample
        """

        # Initialize the base class so its attributes (e.g. `batch_shape`) exist.
        # Argument validation is switched off since no argument constraints are
        # declared for this class. The batch/event shapes are left at the base class
        # defaults - NOTE(review): confirm whether they should mirror `samples.shape`.
        super().__init__(validate_args=False)

        self.loc = self._samples = samples
        self.scale = torch.zeros_like(samples)

    def sample(self, sample_shape=torch.Size()):
        """
        Returns the stored samples as they are; no new values are drawn.

        :param sample_shape: Must be empty or equal to the shape of the stored samples
        :return: The stored samples
        :raises ValueError: If ``sample_shape`` is neither empty nor the samples' shape
        """

        if sample_shape != self._samples.shape and sample_shape != torch.Size():
            raise ValueError('Current implementation only allows passing an empty size!')

        return self._samples