from __future__ import division, print_function, absolute_import
import numpy as np
from scipy.interpolate import interp1d
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
from copy import deepcopy as make_copy
from scipy.ndimage import binary_dilation
from .config import BIG_NUMBER, MIN_LOG, MIN_INTEGRATION_PEAK, TINY_NUMBER

class Distribution(object):
    """
    Class to implement the probability distribution. This class wraps the scipy
    linear interpolation object, and implements some additional operations,
    needed to manipulate distributions for tree nodes positions, branch lengths,
    etc.
    This class is callable, so it can be treated similarly to the scipy interpolation
    object.

    Values are stored as negative log-probabilities. ``peak_val`` holds the
    minimum of these values and the interpolation object ``_func`` holds the
    values with that minimum subtracted, so that inside the support
    ``self(x) == peak_val + _func(x)``; outside the support ``self(x)``
    returns ``peak_val + BIG_NUMBER``. Passing a scalar ``x`` to the
    constructor creates a delta function located at ``x``.
    """

    @staticmethod
    def calc_fwhm(distribution, is_neg_log=True):
        """
        Assess the width of the probability distribution. This returns
        full-width-half-max.

        The width is measured on the interpolation grid: it spans all grid
        points whose probability is at least half of the peak probability,
        widened by one grid point on either side.

        Parameters
        ----------
        distribution : scipy.interpolate.interp1d or Distribution
            Distribution to measure.
        is_neg_log : bool
            Only used for ``interp1d`` input: True if its y-values are negative
            log-probabilities, False if they are plain probabilities.

        Returns
        -------
        float
            The width, at least ``TINY_NUMBER``. A delta-function Distribution
            or a grid with fewer than two points yields 0.

        Raises
        ------
        TypeError
            If ``distribution`` is neither an ``interp1d`` nor a ``Distribution``.
        """
        if isinstance(distribution, interp1d):
            if is_neg_log:
                log_prob = distribution.y - distribution.y.min()
            else:
                log_prob = -np.log(distribution.y)
                log_prob -= log_prob.min()
            xvals = distribution.x

        elif isinstance(distribution, Distribution):
            if distribution.is_delta:
                # a delta function has no grid (its _func is not an interp1d)
                return 0
            xvals = distribution._func.x
            # the stored values already have the peak subtracted; subtracting
            # the minimum again guarantees the half-max test below matches at
            # least one grid point
            log_prob = distribution._func.y - distribution._func.y.min()
        else:
            raise TypeError("Error in computing the FWHM for the distribution. "
                " The input should be either Distribution or interpolation object")

        L = xvals.shape[0]
        if L < 2:
            print ("Not enough points to compute FWHM: returning zero")
            return 0.0
        # 0.69... is log(2): these grid points have at least half the peak
        # probability. There is always one since the minimum is 0.
        above_half = np.where(log_prob < 0.693147)[0]
        x_l, x_u = above_half[0], above_half[-1]
        # step one grid point outwards, guarding against out-of-bounds indices
        return max(TINY_NUMBER, xvals[min(x_u+1, L-1)] - xvals[max(0, x_l-1)])


    @classmethod
    def delta_function(cls, x_pos, weight=1., min_width=MIN_INTEGRATION_PEAK):
        """
        Create delta function distribution.

        The delta is located at ``x_pos`` with a neg-log peak value of 0; its
        ``weight`` is what ``integrate`` returns for it.
        """
        distribution = cls(x_pos, 0., is_log=True, min_width=min_width)
        distribution.weight = weight
        return distribution


    @classmethod
    def shifted_x(cls, dist, delta_x):
        """
        Return a new distribution equal to ``dist`` translated by ``delta_x``
        along the x-axis. ``dist`` itself is not modified; its interpolation
        kind and ``min_width`` are carried over.
        """
        if dist.is_delta:
            shifted = cls(dist.peak_pos + delta_x, dist.peak_val,
                          min_width=dist.min_width)
            if hasattr(dist, 'weight'):
                shifted.weight = dist.weight
            return shifted
        return cls(dist.x + delta_x, dist.y, kind=dist.kind,
                   min_width=dist.min_width)


    @staticmethod
    def multiply(dists):
        '''
        multiplies a list of Distribution objects

        Because values are negative log-probabilities, multiplying the
        probabilities amounts to summing the stored values on a common grid.

        Parameters
        ----------
        dists : iterable of Distribution

        Returns
        -------
        Distribution
            * Exactly one factor is a delta function: a delta function at its
              position, whose weight is the delta's weight times the
              probabilities of all other factors at that position.
            * Otherwise: a linearly interpolated distribution on the union of
              all grid points inside the overlap of the supports, restricted
              to points less than BIG_NUMBER/1000 above the peak. If a single
              point survives, a delta function at that point is returned; if
              the supports do not overlap, an error is printed and a flat
              distribution on [0, 1] with value BIG_NUMBER is returned.

        Raises
        ------
        NotImplementedError
            If any element is not a Distribution.
        ArithmeticError
            If more than one element is a delta function.
        '''
        # materialize so that generators can be iterated more than once
        dists = list(dists)
        if not all(isinstance(k, Distribution) for k in dists):
            raise NotImplementedError("Can only multiply Distribution objects")

        delta_flags = [k.is_delta for k in dists]
        n_delta = np.sum(delta_flags)
        min_width = np.max([k.min_width for k in dists])
        if n_delta > 1:
            raise ArithmeticError("Cannot multiply more than one delta functions!")
        if n_delta == 1:
            delta_dist_ii = int(np.where(delta_flags)[0][0])
            delta_dist = dists[delta_dist_ii]
            new_xpos = delta_dist.peak_pos
            # the delta contributes through its weight; exclude it (by index)
            # from the product of the other factors' probabilities
            new_weight = np.prod([k.prob(new_xpos) for ii, k in enumerate(dists)
                                  if ii != delta_dist_ii]) * delta_dist.weight
            return Distribution.delta_function(new_xpos, weight=new_weight,
                                               min_width=min_width)

        new_xmin = np.max([k.xmin for k in dists])
        new_xmax = np.min([k.xmax for k in dists])

        x_vals = np.unique(np.concatenate([k.x for k in dists]))
        x_vals = x_vals[(x_vals > new_xmin - TINY_NUMBER) & (x_vals < new_xmax + TINY_NUMBER)]
        n_points = 0
        if x_vals.size:
            y_vals = np.sum([k(x_vals) for k in dists], axis=0)
            peak = y_vals.min()
            # drop points that are essentially impossible relative to the peak
            ind = (y_vals - peak) < BIG_NUMBER/1000
            n_points = ind.sum()

        if n_points == 0:
            print ("ERROR in distribution multiplication: Distributions do not overlap")
            return Distribution([0, 1], [BIG_NUMBER, BIG_NUMBER], is_log=True,
                                min_width=min_width, kind='linear')
        if n_points == 1:
            # the only surviving point is the peak itself
            return Distribution.delta_function(x_vals[ind][0], min_width=min_width)
        return Distribution(x_vals[ind], y_vals[ind], is_log=True,
                            min_width=min_width, kind='linear', assume_sorted=True)


    def __init__(self, x, y, is_log=True, min_width = MIN_INTEGRATION_PEAK,
                 kind='linear', assume_sorted=False):

        """
        Create Distribution instance

        Parameters
        ----------
        x : iterable or scalar
            Grid positions, or the position of a delta function if scalar.
        y : iterable, scalar or None
            Values at ``x``: negative log-probabilities if ``is_log`` else
            probabilities. For a delta function a scalar neg-log value, or
            None (stored as ``np.inf``).
        is_log : bool
            Whether ``y`` holds negative log-probabilities.
        min_width : float
            Lower bound for the half-width of the peak window used by
            ``integrate_simpson``.
        kind : str
            Interpolation kind passed to ``scipy.interpolate.interp1d``.
        assume_sorted : bool
            Skip sorting the points by ``x``.

        Raises
        ------
        TypeError
            If ``x``/``y`` are neither both iterables nor a scalar ``x``.
        """

        self.min_width = min_width
        if isinstance(x, Iterable) and isinstance (y, Iterable):

            self._delta = False # NOTE in classmethod this value is set explicitly to True.
            # first, prepare x, y values as fresh float arrays so that the
            # in-place operations below never modify the caller's data
            if assume_sorted:
                xvals = np.array(x, dtype=float)
                yvals = np.array(y, dtype=float)
            else:
                xvals, yvals = np.array(sorted(zip(x, y)), dtype=float).T
            if not is_log:
                yvals = -np.log(yvals)
            # just for safety
            yvals[np.isnan(yvals)] = BIG_NUMBER
            # set the properties
            self._kind = kind
            # remember range
            self._xmin, self._xmax = xvals[0], xvals[-1]
            self._support = self._xmax - self._xmin
            # extract peak
            self._peak_idx = yvals.argmin()
            self._peak_val = yvals.min()
            self._peak_pos = xvals[self._peak_idx]
            yvals -= self._peak_val
            self._ymax = yvals.max()
            # store the interpolation object
            self._func = self._make_interpolator(xvals, yvals)
            self._fwhm = Distribution.calc_fwhm(self)

        elif np.isscalar(x):
            assert (np.isscalar(y) or y is None)
            self._delta = True
            self._kind = kind
            self._peak_pos = x
            self._fwhm = 0
            if y is None:
                self._peak_val = np.inf
            else:
                self._peak_val = y

            self._xmin, self._xmax = x, x
            self._support = 0.
            self._func = lambda x : (x==self.peak_pos)*self.peak_val
        else:
            raise TypeError("Cannot create Distribution: "
                "Input arguments should be scalars or iterables!")


    def _make_interpolator(self, xvals, yvals):
        """Build the interpolation object for sorted x and peak-subtracted values."""
        return interp1d(xvals, yvals, kind=self._kind, fill_value=BIG_NUMBER,
                        bounds_error=False, assume_sorted=True)


    @property
    def is_delta(self):
        """True if this distribution is a delta function."""
        return self._delta

    @property
    def kind(self):
        """Interpolation kind passed to interp1d."""
        return self._kind

    @property
    def peak_val(self):
        """Minimal negative log-probability."""
        return self._peak_val

    @property
    def peak_pos(self):
        """Position of the minimal negative log-probability."""
        return self._peak_pos

    @property
    def peak_idx(self):
        """Grid index of the peak (grid distributions only)."""
        return self._peak_idx

    @property
    def support(self):
        """Width of the support, xmax - xmin."""
        return self._support

    @property
    def fwhm(self):
        """Full-width-half-max as computed by calc_fwhm (0 for a delta)."""
        return self._fwhm

    @property
    def x(self):
        """Grid positions (a one-element list for a delta function)."""
        if self.is_delta:
            return [self._peak_pos]
        else:
            return self._func.x

    @property
    def y(self):
        """Negative log-probabilities on the grid (peak value added back)."""
        if self.is_delta:
            print("THIS SHOULDN'T BE CALLED ON A DELTA FUNCTION")
            return [self.weight]
        else:
            return self._peak_val + self._func.y

    @property
    def xmin(self):
        """Lower end of the support."""
        return self._xmin

    @property
    def xmax(self):
        """Upper end of the support."""
        return self._xmax


    def __call__(self, x):
        """
        Evaluate the negative log-probability at ``x`` (scalar or array-like).
        Points outside the support evaluate to ``peak_val + BIG_NUMBER``.

        Raises
        ------
        TypeError
            If ``x`` is neither iterable nor real.
        """
        if isinstance(x, Iterable):
            x = np.asarray(x)
            valid_idxs = (x > self._xmin-TINY_NUMBER) & (x < self._xmax+TINY_NUMBER)
            res = np.ones_like (x, dtype=float) * (BIG_NUMBER+self.peak_val)
            # float copy: clamping by TINY_NUMBER must not be truncated for int input
            tmp_x = np.array(x[valid_idxs], dtype=float)
            tmp_x[tmp_x<self._xmin+TINY_NUMBER] = self._xmin+TINY_NUMBER
            tmp_x[tmp_x>self._xmax-TINY_NUMBER] = self._xmax-TINY_NUMBER
            res[valid_idxs] = self._peak_val + self._func(tmp_x)
            return res

        elif np.isreal(x):
            if x < self._xmin or x > self._xmax:
                return BIG_NUMBER+self.peak_val
            # x is within interpolation range
            elif self._delta == True:
                return self._peak_val
            else:
                return self._peak_val + self._func(x)
        else:
            raise TypeError("Wrong type: should be float or array")


    def __mul__(self, other):
        return Distribution.multiply((self, other))


    def _adjust_grid(self, rel_tol=0.01, yc=10):
        """
        Thin out the interpolation grid in place.

        While the grid has more than 200 points (at most 5 passes), every other
        interior point is dropped unless the magnitude of its linear
        interpolation error exceeds ``rel_tol*(1 + (dy/yc)**4)``, where ``dy``
        is its height above the peak (so the tolerance grows away from the
        peak). Afterwards the peak is re-located on the remaining grid and the
        stored values are re-normalized to it.
        """
        if self.is_delta:
            return
        xvals = self._func.x
        yvals = self.y
        n_iter = 0
        while len(yvals) > 200 and n_iter < 5:
            n_iter += 1
            # |2*y_i - y_{i-1} - y_{i+1}| is twice the error of linearly
            # interpolating y_i from its neighbours
            interp_err = np.abs(2*yvals[1:-1] - yvals[2:] - yvals[:-2])
            dy = yvals - self.peak_val
            # candidates are odd interior points; True means keep
            keep = interp_err[::2] > rel_tol*(1 + (dy[1:-1:2]/yc)**4)
            if keep.all():
                break
            ind = np.ones_like(yvals, dtype=bool)
            ind[1:-1:2] = keep
            xvals, yvals = xvals[ind], yvals[ind]

        # the old peak point may have been pruned: re-locate the peak and
        # re-normalize so that self(x) == peak_val + _func(x) still holds
        self._peak_idx = yvals.argmin()
        self._peak_pos = xvals[self._peak_idx]
        self._peak_val = yvals[self._peak_idx]
        shifted = yvals - self._peak_val
        self._ymax = shifted.max()
        self._func = self._make_interpolator(xvals, shifted)
        self._fwhm = Distribution.calc_fwhm(self)


    def prob(self,x):
        """Probability exp(-self(x))."""
        return np.exp(-1 * self.__call__(x))

    def prob_relative(self,x):
        """Probability relative to the peak, exp(-(self(x) - peak_val))."""
        return np.exp(-1 * (self.__call__(x)-self.peak_val))

    def x_rescale(self, factor):
        """
        Rescale the x-axis in place by ``factor``. A negative factor mirrors
        the distribution; the grid is then reversed so it stays increasing.
        Support, peak position/index and FWHM are updated accordingly.
        """
        self._peak_pos *= factor
        if factor >= 0:
            self._xmin *= factor
            self._xmax *= factor
        else:
            self._xmin, self._xmax = factor*self._xmax, factor*self._xmin
        self._support = self._xmax - self._xmin
        if self.is_delta:
            return

        xvals = self._func.x * factor
        yvals = self._func.y
        if factor < 0:
            xvals, yvals = xvals[::-1], yvals[::-1]
            self._peak_idx = len(xvals) - 1 - self._peak_idx
        # rebuild instead of mutating interp1d attributes, which spline kinds ignore
        self._func = self._make_interpolator(xvals, yvals)
        self._fwhm = Distribution.calc_fwhm(self)


    def integrate(self, return_log=False ,**kwargs):
        """
        Integrate the probability.

        For a delta function this is its ``weight``; otherwise
        ``exp(-peak_val)`` times ``integrate_simpson(**kwargs)``.
        With ``return_log`` the natural log is returned, floored at about
        ``-BIG_NUMBER`` (relative to the peak) for a zero/non-positive result.
        """
        if self.is_delta:
            if not return_log:
                return self.weight
            return np.log(self.weight) if self.weight > 0 else -BIG_NUMBER
        else:
            integral_result = self.integrate_simpson(**kwargs)
            if return_log:
                if integral_result <= 0:
                    return -self.peak_val - BIG_NUMBER
                else:
                    return -self.peak_val + max(-BIG_NUMBER, np.log(integral_result))
            else:
                return np.exp(-self.peak_val)*integral_result

    def integrate_trapez(self, a=None, b=None,n=None):
        """
        Trapezoidal-rule integral of ``prob_relative`` from ``a`` to ``b``
        using ``n`` equally spaced points. ``a``/``b`` default to the support
        bounds; ``a > b`` yields the negated integral.

        Raises
        ------
        TypeError
            If ``n`` is not given.
        """
        a = self.xmin if a is None else a
        b = self.xmax if b is None else b
        if n is None:
            raise TypeError("integrate_trapez requires the number of points n")
        mult = 0.5
        if a>b:
            b,a = a,b
            mult=-0.5

        x = np.linspace(a,b,n)
        dx = np.diff(x)
        y = self.prob_relative(x)
        return mult*np.sum(dx*(y[:-1] + y[1:]))


    def integrate_simpson(self, a=None,b=None,n=None):
        """
        Composite Simpson's-rule integral of ``prob_relative`` from ``a`` to
        ``b``.

        The interval is split at ``peak_pos -/+ dpeak`` with
        ``dpeak = max(10*fwhm, min_width)`` (where these fall inside [a, b]),
        and each piece is integrated with ``n`` points (made odd, at least 3).
        ``a``/``b`` default to the support bounds; ``a > b`` yields the negated
        integral.

        Raises
        ------
        TypeError
            If ``n`` is not given.
        """
        a = self.xmin if a is None else a
        b = self.xmax if b is None else b
        if n is None:
            raise TypeError("integrate_simpson requires the number of points n")
        sign = 1.0
        if a > b:
            a, b = b, a
            sign = -1.0
        # Simpson's rule needs an odd number of points, and at least 3
        if n % 2 == 0:
            n += 1
        n = max(n, 3)
        mult = 1.0/6
        dpeak = max(10*self.fwhm, self.min_width)
        threshold = np.array([a,self.peak_pos-dpeak, self.peak_pos+dpeak,b])
        threshold = threshold[(threshold>=a)&(threshold<=b)]
        threshold.sort()
        res = []
        for lw, up in zip(threshold[:-1], threshold[1:]):
            x = np.linspace(lw,up,n)
            # spacing between consecutive even-indexed nodes (twice the step)
            dx = np.diff(x[::2])
            y = self.prob_relative(x)
            res.append(mult*(dx[0]*y[0]+ np.sum(4*dx*y[1:-1:2])
                    + np.sum((dx[:-1]+dx[1:])*y[2:-1:2]) + dx[-1]*y[-1]))

        return sign*np.sum(res)

if __name__=="__main__":
    # Code used for debugging and development: compares interpolation kinds
    # and integration schemes on two analytic neg-log test functions and plots
    # the resulting errors. Requires matplotlib.
    from matplotlib import pyplot as plt
    plt.ion()

    x = [-1e-10,  0.,    1.,  2.,    2.+1e-10]
    y = [  1e-10, 1e-10, 10., 1e-10, 1e-10]
    d1 = Distribution(x, y,is_log=False)

    def f(x):
        return (x**2-5)**2 #(x-5)**2+np.abs(x)**3
    def g(x):
        return (x-4)**2*(x**(1.0/3)-5)**2

    # measure interpolation accuracy
    plot=False
    error = {}
    for kind in ['linear', 'quadratic', 'cubic', 'Q']:
        error[kind]=[[],[]]
    npoints = [11,21] #,31,41, 51,75,101]
    for ex, func in [[0,f],[1,g]]:
        for npoint in npoints:
            if ex==0:
                xnew = np.linspace(-5,15,1000)
                x = np.linspace(0,10,npoint)
            elif ex==1:
                xnew = np.linspace(0,150,1000)
                x = np.linspace(0,9.0,npoint)**3

            if plot:
                plt.figure()
                plt.plot(x, np.exp(-func(x)),'-o', label = 'data')
                plt.plot(xnew, np.exp(-func(xnew)),'-',label='true')
            for kind in ['linear', 'quadratic', 'cubic']:
                try:
                    dist = Distribution(x, func(x), kind=kind, is_log=True)
                    if plot: plt.plot(xnew, dist.prob(xnew), label=kind)
                    E = np.mean((np.exp(-func(xnew))-dist.prob(xnew))[(xnew>dist.xmin) & (xnew<dist.xmax)]**2)
                    print(kind,npoint, E)
                except Exception:
                    # record failed interpolation kinds as NaN, keep going
                    E=np.nan
                error[kind][ex].append(E)
            try:
                # NOTE(review): quadratic_interpolator is neither defined nor
                # imported in this module, so this raises NameError and the
                # 'Q' series is recorded as NaN.
                distQ = quadratic_interpolator(x, func(x))
                if plot: plt.plot(xnew[(xnew>dist.xmin) & (xnew<dist.xmax)], np.exp(-distQ(xnew))[(xnew>dist.xmin) & (xnew<dist.xmax)], label='Q')
                E = np.mean((np.exp(-func(xnew))-np.exp(-distQ(xnew)))[(xnew>dist.xmin) & (xnew<dist.xmax)]**2)
                print('Q',npoint, E)
            except Exception:
                E=np.nan
            error['Q'][ex].append(E)
            if plot:
                plt.yscale('log')
                plt.legend()

    for ex in [0,1]:
        plt.figure()
        for k in error:
            plt.plot(npoints, error[k][ex],'-o', label=k)
        plt.yscale('log')
        plt.legend()

    # measure integration accuracy
    integration_error = {'trapez':[], 'simpson':[], 'piecewise':[]}
    npoints = [11,21,31, 41, 51,75,101, 201, 501, 1001]
    xnew = np.linspace(-5,15,1000)
    x = np.linspace(0,10,3000)
    dist = Distribution(x, g(x), kind='linear', is_log=True)
    for npoint in npoints:
        integration_error['trapez'].append(dist.integrate_trapez(0,10,npoint))
        integration_error['simpson'].append(dist.integrate_simpson(0,10,npoint))
        # placeholder for a 'piecewise' scheme that is not evaluated yet
        xtmp = np.linspace(0,10,min(100,npoint))
        disttmp = Distribution(xtmp, g(xtmp), kind='cubic', is_log=True)

    plt.figure()
    # the finest Simpson result serves as the reference value
    base_line = integration_error['simpson'][-1]
    plt.plot(npoints, np.abs(np.asarray(integration_error['trapez'])-base_line), label='trapez')
    plt.plot(npoints, np.abs(np.asarray(integration_error['simpson'])-base_line), label='simpson')
    plt.xlabel('npoints')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()