from __future__ import division, print_function, absolute_import

import os

from distutils.version import LooseVersion

import functools

import numpy as np
from numpy.testing import assert_
import pytest

import scipy.special as sc

__all__ = ['with_special_errors', 'assert_tol_equal', 'assert_func_equal',
           'FuncData']


#------------------------------------------------------------------------------
# Check if a module is present to be used in tests
#------------------------------------------------------------------------------

class MissingModule(object):
    def __init__(self, name):
        self.name = name


def check_version(module, min_ver):
    if type(module) == MissingModule:
        return pytest.mark.skip(reason="{} is not installed".format(module.name))
    return pytest.mark.skipif(LooseVersion(module.__version__) < LooseVersion(min_ver),
                              reason="{} version >= {} required".format(module.__name__, min_ver))


#------------------------------------------------------------------------------
# Enable convergence and loss of precision warnings -- turn off one by one
#------------------------------------------------------------------------------

def with_special_errors(func):
    """
    Enable special function errors (such as underflow, overflow,
    loss of precision, etc.)
    """
    @functools.wraps(func)
    def wrapper(*a, **kw):
        with sc.errstate(all='raise'):
            res = func(*a, **kw)
        return res
    return wrapper


#------------------------------------------------------------------------------
# Comparing function values at many data points at once, with helpful
#------------------------------------------------------------------------------

def assert_tol_equal(a, b, rtol=1e-7, atol=0, err_msg='', verbose=True):
    """Assert that `a` and `b` are equal to tolerance ``atol + rtol*abs(b)``"""
    def compare(x, y):
        return np.allclose(x, y, rtol=rtol, atol=atol)
    a, b = np.asanyarray(a), np.asanyarray(b)
    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
    np.testing.utils.assert_array_compare(compare, a, b, err_msg=str(err_msg),
                                          verbose=verbose, header=header)


#------------------------------------------------------------------------------
# Comparing function values at many data points at once, with helpful
# error reports
#------------------------------------------------------------------------------

def assert_func_equal(func, results, points, rtol=None, atol=None,
                      param_filter=None, knownfailure=None,
                      vectorized=True, dtype=None, nan_ok=False,
                      ignore_inf_sign=False, distinguish_nan_and_inf=True):
    if hasattr(points, 'next'):
        # it's a generator
        points = list(points)

    points = np.asarray(points)
    if points.ndim == 1:
        points = points[:,None]
    nparams = points.shape[1]

    if hasattr(results, '__name__'):
        # function
        data = points
        result_columns = None
        result_func = results
    else:
        # dataset
        data = np.c_[points, results]
        result_columns = list(range(nparams, data.shape[1]))
        result_func = None

    fdata = FuncData(func, data, list(range(nparams)),
                     result_columns=result_columns, result_func=result_func,
                     rtol=rtol, atol=atol, param_filter=param_filter,
                     knownfailure=knownfailure, nan_ok=nan_ok, vectorized=vectorized,
                     ignore_inf_sign=ignore_inf_sign,
                     distinguish_nan_and_inf=distinguish_nan_and_inf)
    fdata.check()


class FuncData(object):
    """
    Data set for checking a special function.

    Parameters
    ----------
    func : function
        Function to test
    filename : str
        Input file name
    param_columns : int or tuple of ints
        Columns indices in which the parameters to `func` lie.
        Can be imaginary integers to indicate that the parameter
        should be cast to complex.
    result_columns : int or tuple of ints, optional
        Column indices for expected results from `func`.
    result_func : callable, optional
        Function to call to obtain results.
    rtol : float, optional
        Required relative tolerance. Default is 5*eps.
    atol : float, optional
        Required absolute tolerance. Default is 5*tiny.
    param_filter : function, or tuple of functions/Nones, optional
        Filter functions to exclude some parameter ranges.
        If omitted, no filtering is done.
    knownfailure : str, optional
        Known failure error message to raise when the test is run.
        If omitted, no exception is raised.
    nan_ok : bool, optional
        If nan is always an accepted result.
    vectorized : bool, optional
        Whether all functions passed in are vectorized.
    ignore_inf_sign : bool, optional
        Whether to ignore signs of infinities.
        (Doesn't matter for complex-valued functions.)
    distinguish_nan_and_inf : bool, optional
        If True, treat numbers which contain nans or infs as as
        equal. Sets ignore_inf_sign to be True.

    """

    def __init__(self, func, data, param_columns, result_columns=None,
                 result_func=None, rtol=None, atol=None, param_filter=None,
                 knownfailure=None, dataname=None, nan_ok=False, vectorized=True,
                 ignore_inf_sign=False, distinguish_nan_and_inf=True):
        self.func = func
        self.data = data
        self.dataname = dataname
        if not hasattr(param_columns, '__len__'):
            param_columns = (param_columns,)
        self.param_columns = tuple(param_columns)
        if result_columns is not None:
            if not hasattr(result_columns, '__len__'):
                result_columns = (result_columns,)
            self.result_columns = tuple(result_columns)
            if result_func is not None:
                raise ValueError("Only result_func or result_columns should be provided")
        elif result_func is not None:
            self.result_columns = None
        else:
            raise ValueError("Either result_func or result_columns should be provided")
        self.result_func = result_func
        self.rtol = rtol
        self.atol = atol
        if not hasattr(param_filter, '__len__'):
            param_filter = (param_filter,)
        self.param_filter = param_filter
        self.knownfailure = knownfailure
        self.nan_ok = nan_ok
        self.vectorized = vectorized
        self.ignore_inf_sign = ignore_inf_sign
        self.distinguish_nan_and_inf = distinguish_nan_and_inf
        if not self.distinguish_nan_and_inf:
            self.ignore_inf_sign = True

    def get_tolerances(self, dtype):
        if not np.issubdtype(dtype, np.inexact):
            dtype = np.dtype(float)
        info = np.finfo(dtype)
        rtol, atol = self.rtol, self.atol
        if rtol is None:
            rtol = 5*info.eps
        if atol is None:
            atol = 5*info.tiny
        return rtol, atol

    def check(self, data=None, dtype=None):
        """Check the special function against the data."""

        if self.knownfailure:
            pytest.xfail(reason=self.knownfailure)

        if data is None:
            data = self.data

        if dtype is None:
            dtype = data.dtype
        else:
            data = data.astype(dtype)

        rtol, atol = self.get_tolerances(dtype)

        # Apply given filter functions
        if self.param_filter:
            param_mask = np.ones((data.shape[0],), np.bool_)
            for j, filter in zip(self.param_columns, self.param_filter):
                if filter:
                    param_mask &= list(filter(data[:,j]))
            data = data[param_mask]

        # Pick parameters from the correct columns
        params = []
        for j in self.param_columns:
            if np.iscomplexobj(j):
                j = int(j.imag)
                params.append(data[:,j].astype(complex))
            else:
                params.append(data[:,j])

        # Helper for evaluating results
        def eval_func_at_params(func, skip_mask=None):
            if self.vectorized:
                got = func(*params)
            else:
                got = []
                for j in range(len(params[0])):
                    if skip_mask is not None and skip_mask[j]:
                        got.append(np.nan)
                        continue
                    got.append(func(*tuple([params[i][j] for i in range(len(params))])))
                got = np.asarray(got)
            if not isinstance(got, tuple):
                got = (got,)
            return got

        # Evaluate function to be tested
        got = eval_func_at_params(self.func)

        # Grab the correct results
        if self.result_columns is not None:
            # Correct results passed in with the data
            wanted = tuple([data[:,icol] for icol in self.result_columns])
        else:
            # Function producing correct results passed in
            skip_mask = None
            if self.nan_ok and len(got) == 1:
                # Don't spend time evaluating what doesn't need to be evaluated
                skip_mask = np.isnan(got[0])
            wanted = eval_func_at_params(self.result_func, skip_mask=skip_mask)

        # Check the validity of each output returned
        assert_(len(got) == len(wanted))

        for output_num, (x, y) in enumerate(zip(got, wanted)):
            if np.issubdtype(x.dtype, np.complexfloating) or self.ignore_inf_sign:
                pinf_x = np.isinf(x)
                pinf_y = np.isinf(y)
                minf_x = np.isinf(x)
                minf_y = np.isinf(y)
            else:
                pinf_x = np.isposinf(x)
                pinf_y = np.isposinf(y)
                minf_x = np.isneginf(x)
                minf_y = np.isneginf(y)
            nan_x = np.isnan(x)
            nan_y = np.isnan(y)

            olderr = np.seterr(all='ignore')
            try:
                abs_y = np.absolute(y)
                abs_y[~np.isfinite(abs_y)] = 0
                diff = np.absolute(x - y)
                diff[~np.isfinite(diff)] = 0

                rdiff = diff / np.absolute(y)
                rdiff[~np.isfinite(rdiff)] = 0
            finally:
                np.seterr(**olderr)

            tol_mask = (diff <= atol + rtol*abs_y)
            pinf_mask = (pinf_x == pinf_y)
            minf_mask = (minf_x == minf_y)

            nan_mask = (nan_x == nan_y)

            bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)

            point_count = bad_j.size
            if self.nan_ok:
                bad_j &= ~nan_x
                bad_j &= ~nan_y
                point_count -= (nan_x | nan_y).sum()

            if not self.distinguish_nan_and_inf and not self.nan_ok:
                # If nan's are okay we've already covered all these cases
                inf_x = np.isinf(x)
                inf_y = np.isinf(y)
                both_nonfinite = (inf_x & nan_y) | (nan_x & inf_y)
                bad_j &= ~both_nonfinite
                point_count -= both_nonfinite.sum()

            if np.any(bad_j):
                # Some bad results: inform what, where, and how bad
                msg = [""]
                msg.append("Max |adiff|: %g" % diff.max())
                msg.append("Max |rdiff|: %g" % rdiff.max())
                msg.append("Bad results (%d out of %d) for the following points (in output %d):"
                           % (np.sum(bad_j), point_count, output_num,))
                for j in np.where(bad_j)[0]:
                    j = int(j)
                    fmt = lambda x: "%30s" % np.array2string(x[j], precision=18)
                    a = "  ".join(map(fmt, params))
                    b = "  ".join(map(fmt, got))
                    c = "  ".join(map(fmt, wanted))
                    d = fmt(rdiff)
                    msg.append("%s => %s != %s  (rdiff %s)" % (a, b, c, d))
                assert_(False, "\n".join(msg))

    def __repr__(self):
        """Pretty-printing, esp. for Nose output"""
        if np.any(list(map(np.iscomplexobj, self.param_columns))):
            is_complex = " (complex)"
        else:
            is_complex = ""
        if self.dataname:
            return "<Data for %s%s: %s>" % (self.func.__name__, is_complex,
                                            os.path.basename(self.dataname))
        else:
            return "<Data for %s%s>" % (self.func.__name__, is_complex)