from __future__ import division, print_function, absolute_import
import numpy as np
import scipy.sparse as spc
from scipy.sparse.linalg import LinearOperator
from ._numdiff import approx_derivative
from warnings import warn


# Public names of this module (what ``from ... import *`` exposes).
__all__ = ['NonlinearConstraint',
           'LinearConstraint',
           'BoxConstraint']


class NonlinearConstraint:
    """Nonlinear constraint.

    Parameters
    ----------
    fun : callable
        The function defining the constraint.

            fun(x) -> array_like, shape (m,)

        where x is a (n,) ndarray and ``m``
        is the number of constraints.
    kind : {str, tuple}
        Specifies the type of constraint. Options for this
        parameter are:

            - ('interval', lb, ub) for a constraint of the type:
                lb <= fun(x) <= ub
            - ('greater', lb) for a constraint of the type:
                fun(x) >= lb
            - ('less', ub) for a constraint of the type:
                fun(x) <= ub
            - ('equals', c) for a constraint of the type:
                fun(x) == c
            - ('greater',) for a constraint of the type:
                fun(x) >= 0
            - ('less',) for a constraint of the type:
                fun(x) <= 0
            - ('equals',) for a constraint of the type:
                fun(x) == 0

        where ``lb``, ``ub`` and ``c`` are (m,) ndarrays or
        scalar values. In the latter case, the same value
        will be repeated for all the constraints.
    jac : callable
        Jacobian matrix:

            jac(x) -> {ndarray, sparse matrix}, shape (m, n)

        where x is a (n,) ndarray.
    hess : {callable, '2-point', '3-point', 'cs', None}
        Method for computing the Hessian matrix. The keywords
        select a finite difference scheme for numerical
        estimation, applied to ``jac(x).T.dot(v)``. The scheme '3-point'
        is more accurate, but requires twice as many operations compared
        to '2-point' (default). The scheme 'cs' uses complex steps, and
        while potentially the most accurate, it is applicable only when
        `jac` correctly handles complex inputs and can be analytically
        continued to the complex plane. If it is a callable, it should
        return the Hessian matrix of `dot(fun, v)`:

            hess(x, v) -> {LinearOperator, sparse matrix, ndarray}, shape (n, n)

        where x is a (n,) ndarray and v is a (m,) ndarray. When ``hess``
        is None the Hessian is considered to be a matrix filled with zeros.
        Any other value raises ``ValueError`` in
        ``evaluate_and_initialize``.
    enforce_feasibility : {list of boolean, boolean}, optional
        Specify if the constraint must be feasible along the iterations.
        If ``True`` all the iterates generated by the optimization
        algorithm need to be feasible with respect to a constraint. If
        ``False`` this is not needed. A list can be passed to specify,
        element-wise, which constraints need to stay feasible along the
        iterations and which do not. Alternatively, a single boolean can
        be used to specify the feasibility required of all constraints.
        By default it is False.
    """
    def __init__(self, fun, kind, jac, hess='2-point', enforce_feasibility=False):
        self._fun = fun
        self.kind = kind
        self._jac = jac
        self._hess = hess
        self.enforce_feasibility = enforce_feasibility
        self.isinitialized = False

    def evaluate_and_initialize(self, x0, sparse_jacobian=None):
        """Evaluate the constraint at ``x0`` and build the wrapped callables.

        Sets ``fun``, ``jac``, ``hess``, ``x0``, ``f0``, ``J0``, ``n``,
        ``m`` and ``sparse_jacobian`` on the instance, normalizes ``kind``
        and ``enforce_feasibility`` and marks the constraint initialized.

        Parameters
        ----------
        x0 : array_like, shape (n,)
            Initial point.
        sparse_jacobian : {bool, None}, optional
            If truthy, Jacobians are returned as CSR matrices; if False,
            as dense 2-D ndarrays. If None (default), the format of
            ``jac(x0)`` decides.

        Returns
        -------
        x0 : ndarray, shape (n,)
            The initial point as a 1-D float array.

        Raises
        ------
        ValueError
            If ``hess`` is not a callable, '2-point', '3-point', 'cs' or
            None, or if ``x0`` violates a constraint whose feasibility is
            enforced.
        """
        x0 = np.atleast_1d(x0).astype(float)
        f0 = np.atleast_1d(self._fun(x0))
        v0 = np.zeros_like(f0)
        J0 = self._jac(x0)

        def fun_wrapped(x):
            return np.atleast_1d(self._fun(x))

        if sparse_jacobian or (sparse_jacobian is None and spc.issparse(J0)):
            def jac_wrapped(x):
                return spc.csr_matrix(self._jac(x))
            self.sparse_jacobian = True
            self.J0 = spc.csr_matrix(J0)
        else:
            def jac_wrapped(x):
                J = self._jac(x)
                if spc.issparse(J):
                    return J.toarray()
                return np.atleast_2d(J)
            self.sparse_jacobian = False
            self.J0 = J0.toarray() if spc.issparse(J0) else np.atleast_2d(J0)

        if callable(self._hess):
            # Evaluated once only to find out which output format to keep.
            H0 = self._hess(x0, v0)
            if spc.issparse(H0):
                def hess_wrapped(x, v):
                    return spc.csr_matrix(self._hess(x, v))
            elif isinstance(H0, LinearOperator):
                def hess_wrapped(x, v):
                    return self._hess(x, v)
            else:
                def hess_wrapped(x, v):
                    return np.atleast_2d(np.asarray(self._hess(x, v)))

        elif (isinstance(self._hess, str)
              and self._hess in ('2-point', '3-point', 'cs')):
            approx_method = self._hess

            # The Hessian of dot(fun, v) is the Jacobian of jac(x).T.dot(v),
            # so it is estimated by finite differences of that product.
            def jac_dot_v(x, v):
                return jac_wrapped(x).T.dot(v)

            def hess_wrapped(x, v):
                return approx_derivative(jac_dot_v, x, approx_method,
                                         as_linear_operator=True,
                                         args=(v,))

        elif self._hess is None:
            hess_wrapped = None

        else:
            # Fail now instead of storing a non-callable that would only
            # break later when the optimizer tries to evaluate it.
            raise ValueError("`hess` must be a callable, '2-point', "
                             "'3-point', 'cs' or None.")

        self.fun = fun_wrapped
        self.jac = jac_wrapped
        self.hess = hess_wrapped
        self.x0 = x0
        self.f0 = f0
        self.n = x0.size
        self.m = f0.size
        self.kind = _check_kind(self.kind, self.m)
        self.enforce_feasibility \
            = _check_enforce_feasibility(self.enforce_feasibility, self.m)
        if not _is_feasible(self.kind, self.enforce_feasibility, f0):
            raise ValueError("Unfeasible initial point. "
                             "Either set ``enforce_feasibility=False`` or "
                             "choose a new feasible initial point ``x0``.")

        self.isinitialized = True
        return x0


class LinearConstraint:
    """Linear constraint.

    Parameters
    ----------
    A : {ndarray, sparse matrix}, shape (m, n)
        Matrix for the linear constraint.
    kind : {str, tuple}
        Specifies the type of constraint. Options for this
        parameter are:

            - ('interval', lb, ub) for a constraint of the type:
                lb <= A x <= ub
            - ('greater', lb) for a constraint of the type:
                A x >= lb
            - ('less', ub) for a constraint of the type:
                A x <= ub
            - ('equals', c) for a constraint of the type:
                A x == c
            - ('greater',) for a constraint of the type:
                A x >= 0
            - ('less',) for a constraint of the type:
                A x <= 0
            - ('equals',) for a constraint of the type:
                A x == 0

        where ``lb``, ``ub`` and ``c`` are (m,) ndarrays or
        scalar values. In the latter case, the same value
        will be repeated for all the constraints.
    enforce_feasibility : {list of boolean, boolean}, optional
        Specify if the constraint must be feasible along the iterations.
        If ``True`` all the iterates generated by the optimization
        algorithm need to be feasible with respect to a constraint. If
        ``False`` this is not needed. A list can be passed to specify,
        element-wise, which constraints need to stay feasible along the
        iterations and which do not. Alternatively, a single boolean can
        be used to specify the feasibility required of all constraints.
        By default it is False.
    """
    def __init__(self, A, kind, enforce_feasibility=False):
        self.isinitialized = False
        self.A = A
        self.kind = kind
        self.enforce_feasibility = enforce_feasibility

    def evaluate_and_initialize(self, x0, sparse_jacobian=None):
        """Convert ``A`` to the requested format and evaluate it at ``x0``.

        ``A`` becomes a CSR matrix when ``sparse_jacobian`` is truthy, or
        when it is None and ``A`` is already sparse; otherwise it becomes
        a dense 2-D ndarray. Afterwards ``x0``, ``f0 = A x0``, ``J0 = A``,
        ``n``, ``m`` are stored and ``kind``/``enforce_feasibility`` are
        normalized.

        Returns
        -------
        x0 : ndarray, shape (n,)
            The initial point as a 1-D float array.

        Raises
        ------
        ValueError
            If ``x0`` violates a constraint whose feasibility is enforced.
        """
        use_sparse = bool(sparse_jacobian) or (
            sparse_jacobian is None and spc.issparse(self.A))

        if use_sparse:
            self.A = spc.csr_matrix(self.A)
        elif spc.issparse(self.A):
            self.A = self.A.toarray()
        else:
            self.A = np.atleast_2d(self.A)
        self.sparse_jacobian = use_sparse

        point = np.atleast_1d(x0).astype(float)
        values = self.A.dot(point)

        self.x0, self.f0, self.J0 = point, values, self.A
        self.n, self.m = point.size, values.size
        self.kind = _check_kind(self.kind, self.m)
        self.enforce_feasibility = _check_enforce_feasibility(
            self.enforce_feasibility, self.m)

        if not _is_feasible(self.kind, self.enforce_feasibility, values):
            raise ValueError("Unfeasible initial point. "
                             "Either set ``enforce_feasibility=False`` or "
                             "choose a new feasible initial point ``x0``.")

        self.isinitialized = True
        return point

    def to_nonlinear(self):
        """Return an initialized ``NonlinearConstraint`` equivalent to this one.

        The result evaluates ``self.A.dot(x)``, has the constant Jacobian
        ``self.A`` and a ``None`` Hessian.

        Raises
        ------
        RuntimeError
            If ``evaluate_and_initialize`` has not been called yet.
        """
        if not self.isinitialized:
            raise RuntimeError("Trying to convert uninitialized constraint.")

        def apply_matrix(x):
            return self.A.dot(x)

        def constant_jacobian(x):
            return self.A

        nonlinear = NonlinearConstraint(
            apply_matrix, self.kind, constant_jacobian, hess=None,
            enforce_feasibility=self.enforce_feasibility)

        # This constraint is already evaluated at x0, so copy the cached
        # state instead of re-running evaluate_and_initialize.
        nonlinear.fun = apply_matrix
        nonlinear.jac = constant_jacobian
        nonlinear.hess = None
        nonlinear.x0, nonlinear.f0, nonlinear.J0 = self.x0, self.f0, self.J0
        nonlinear.n, nonlinear.m = self.n, self.m
        nonlinear.sparse_jacobian = self.sparse_jacobian
        nonlinear.isinitialized = True
        return nonlinear


class BoxConstraint:
    """Box constraint.

    Bounds the components of the optimization variable ``x`` directly;
    the constraint function is the identity map.

    Parameters
    ----------
    kind : tuple
        Specifies the type of constraint. Options for this
        parameter are:

            - ('interval', lb, ub) for a constraint of the type:
                lb <= x <= ub
            - ('greater', lb) for a constraint of the type:
                x >= lb
            - ('less', ub) for a constraint of the type:
                x <= ub
            - ('equals', c) for a constraint of the type:
                x == c
            - ('greater',) for a constraint of the type:
                x >= 0
            - ('less',) for a constraint of the type:
                x <= 0
            - ('equals',) for a constraint of the type:
                x == 0

        where ``lb``, ``ub`` and ``c`` are (n,) ndarrays or
        scalar values. In the latter case, the same value
        will be repeated for all the components of ``x``.
    enforce_feasibility : {list of boolean, boolean}, optional
        Specify if the constraint must be feasible along the iterations.
        If ``True`` all the iterates generated by the optimization
        algorithm need to be feasible with respect to a constraint. If
        ``False`` this is not needed. A list can be passed to specify,
        element-wise, which constraints need to stay feasible along the
        iterations and which do not. Alternatively, a single boolean can
        be used to specify the feasibility required of all constraints.
        By default it is False.
    """
    def __init__(self, kind, enforce_feasibility=False):
        self.isinitialized = False
        self.kind = kind
        self.enforce_feasibility = enforce_feasibility

    def evaluate_and_initialize(self, x0, sparse_jacobian=None):
        """Set up the box constraint at ``x0``.

        The Jacobian ``J0`` is the (n, n) identity: a CSR matrix when
        ``sparse_jacobian`` is truthy or None, a dense ndarray otherwise.
        If ``x0`` violates an enforced bound, a warning is emitted and an
        adjusted copy of ``x0`` is used instead.

        Returns
        -------
        x0 : ndarray, shape (n,)
            The (possibly adjusted) initial point as a float array.
        """
        point = np.atleast_1d(x0).astype(float)
        self.n = point.size
        self.m = point.size

        use_sparse = bool(sparse_jacobian) or sparse_jacobian is None
        self.J0 = spc.eye(self.n).tocsr() if use_sparse else np.eye(self.n)
        self.sparse_jacobian = use_sparse

        self.kind = _check_kind(self.kind, self.m)
        self.enforce_feasibility = _check_enforce_feasibility(
            self.enforce_feasibility, self.m)
        self.isinitialized = True

        if _is_feasible(self.kind, self.enforce_feasibility, point):
            self.x0 = point
            self.f0 = point
            return point

        warn("The initial point was changed in order "
             "to stay inside box constraints.")
        moved = _reinforce_box_constraint(self.kind,
                                          self.enforce_feasibility,
                                          point)
        self.x0 = moved
        self.f0 = moved
        return moved

    def to_linear(self):
        """Return an initialized ``LinearConstraint`` with ``A = J0``.

        Raises
        ------
        RuntimeError
            If ``evaluate_and_initialize`` has not been called yet.
        """
        if not self.isinitialized:
            raise RuntimeError("Trying to convert uninitialized constraint.")

        linear = LinearConstraint(self.J0, self.kind,
                                  self.enforce_feasibility)
        linear.isinitialized = True
        # Copy the cached evaluation state onto the new constraint.
        for attr in ("m", "n", "sparse_jacobian", "x0", "f0", "J0"):
            setattr(linear, attr, getattr(self, attr))
        return linear

    def to_nonlinear(self):
        """Return an initialized ``NonlinearConstraint`` via ``to_linear``.

        Raises
        ------
        RuntimeError
            If ``evaluate_and_initialize`` has not been called yet.
        """
        if not self.isinitialized:
            raise RuntimeError("Trying to convert uninitialized constraint.")
        as_linear = self.to_linear()
        return as_linear.to_nonlinear()


# ************************************************************ #
# **********           Auxiliary Functions          ********** #
# ************************************************************ #
def _check_kind(kind, m):
    """Validate and normalize the ``kind`` specification of a constraint.

    Parameters
    ----------
    kind : {str, tuple, list}
        Constraint type, e.g. ``'greater'``, ``('less', ub)`` or
        ``('interval', lb, ub)``.
    m : int
        Number of constraints; every bound is broadcast to this length.

    Returns
    -------
    tuple
        ``(keyword, c)`` for "greater", "less" and "equals", or
        ``(keyword, lb, ub)`` for "interval". Each bound is a float
        ndarray of shape (m,). A bare keyword gets the bound 0.

    Raises
    ------
    ValueError
        If ``kind`` has the wrong type, is empty, uses an unknown keyword,
        has the wrong number of elements for its keyword, a bound has a
        size other than 1 or ``m``, or ``lb > ub`` for some component.
    """
    if not isinstance(kind, (tuple, list, str)):
        raise ValueError("The parameter `kind` should be a tuple, "
                         " a list, or a string.")
    if isinstance(kind, str):
        kind = (kind,)
    if len(kind) == 0:
        raise ValueError("The parameter `kind` should not be empty.")

    n_args = len(kind)
    keyword = kind[0]
    if keyword not in ("greater", "less", "equals", "interval"):
        # Wrap in a 1-tuple so a tuple-valued keyword cannot break the
        # %-formatting.
        raise ValueError("Keyword `%s` not available." % (keyword,))

    # "interval" needs exactly (keyword, lb, ub); the one-sided and
    # equality kinds take (keyword,) or (keyword, bound). Anything longer
    # is rejected instead of silently ignoring the extra entries.
    if keyword == "interval":
        valid_format = n_args == 3
    else:
        valid_format = n_args in (1, 2)
    if not valid_format:
        raise ValueError("Invalid `kind` format.")
    if n_args == 1:
        kind = (keyword, 0)

    def _broadcast(value, name):
        # Accept a scalar or an m-sized bound and repeat it to length m.
        arr = np.asarray(value, dtype=float)
        if np.size(arr) not in (1, m):
            raise ValueError("`%s` has the wrong dimension." % name)
        return np.resize(arr, m)

    if keyword == "interval":
        lb = _broadcast(kind[1], "lb")
        ub = _broadcast(kind[2], "ub")
        if (lb > ub).any():
            raise ValueError("lb[i] > ub[i].")
        return (keyword, lb, ub)

    bound_name = {"greater": "lb", "less": "ub", "equals": "c"}[keyword]
    return (keyword, _broadcast(kind[1], bound_name))


def _check_enforce_feasibility(enforce_feasibility, m):
    """Convert ``enforce_feasibility`` into a boolean mask of shape (m,).

    Parameters
    ----------
    enforce_feasibility : {bool, numpy.bool_, array_like of bool}
        A single boolean applies to all ``m`` constraints; otherwise one
        entry per constraint is expected.
    m : int
        Number of constraints.

    Returns
    -------
    ndarray of bool, shape (m,)
        A new array; the input is never returned or modified.

    Raises
    ------
    ValueError
        If a non-boolean-scalar input does not have exactly ``m`` elements.
    """
    # Accept NumPy's boolean scalar as well as Python's bool.
    if isinstance(enforce_feasibility, (bool, np.bool_)):
        return np.full(m, bool(enforce_feasibility), dtype=bool)

    # np.array copies the input. ravel() turns a 0-d value into shape (1,)
    # and flattens nested input, so the mask can index 1-D bound arrays.
    mask = np.array(enforce_feasibility, dtype=bool).ravel()
    if mask.size != m:
        raise ValueError("The parameter 'enforce_feasibility' "
                         "has the wrong number of elements.")
    return mask


def _is_feasible(kind, enforce_feasibility, f0):
    """Check whether ``f0`` satisfies every enforced constraint.

    Parameters
    ----------
    kind : tuple
        Normalized specification: ``(keyword, c)`` or
        ``('interval', lb, ub)`` with bounds of the same length as ``f0``.
    enforce_feasibility : ndarray of bool
        Mask selecting the constraints that are checked.
    f0 : ndarray
        Constraint values.

    Returns
    -------
    numpy.bool_
        True if ``lb <= f0 <= ub`` for every masked entry. "equals" uses
        ``lb == ub == c``, so it requires exact equality.
    """
    keyword = kind[0]
    if keyword == "interval":
        lower = np.asarray(kind[1], dtype=float)
        upper = np.asarray(kind[2], dtype=float)
    elif keyword == "equals":
        lower = upper = np.asarray(kind[1], dtype=float)
    elif keyword == "greater":
        lower = np.asarray(kind[1], dtype=float)
        upper = np.full_like(lower, np.inf, dtype=float)
    elif keyword == "less":
        upper = np.asarray(kind[1], dtype=float)
        lower = np.full_like(upper, -np.inf, dtype=float)
    else:
        raise RuntimeError("Never be here.")

    mask = enforce_feasibility
    values = f0[mask]
    return np.all(lower[mask] <= values) and np.all(values <= upper[mask])


def _reinforce_box_constraint(kind, enforce_feasibility, x0,
                              relative_tolerance=0.01,
                              absolute_tolerance=0.01):
    """Move the enforced components of ``x0`` inside the box.

    Parameters
    ----------
    kind : tuple
        Normalized box specification: ``(keyword, c)`` with keyword
        "greater", "less" or "equals", or ``('interval', lb, ub)``.
    enforce_feasibility : array_like of bool, shape (n,)
        Mask of the components that must be made feasible.
    x0 : array_like, shape (n,)
        Point to adjust; it is not modified.
    relative_tolerance, absolute_tolerance : float, optional
        Each finite bound is pushed inward by the smaller of
        ``absolute_tolerance`` and ``relative_tolerance * (ub - lb)``.
        Taking the smaller offset keeps the point inside narrow intervals.

    Returns
    -------
    x0_new : ndarray, shape (n,)
        Float copy of ``x0`` with enforced components clipped into the
        shrunken box. For "equals" the enforced components become ``c``.

    Raises
    ------
    ValueError
        If the keyword in ``kind`` is not recognized.
    """
    keyword = kind[0]
    if keyword == "greater":
        lb = np.asarray(kind[1], dtype=float)
        ub = np.full_like(lb, np.inf, dtype=float)
    elif keyword == "less":
        ub = np.asarray(kind[1], dtype=float)
        lb = np.full_like(ub, -np.inf, dtype=float)
    elif keyword == "interval":
        lb = np.asarray(kind[1], dtype=float)
        ub = np.asarray(kind[2], dtype=float)
    elif keyword == "equals":
        # With lb == ub == c the inward shift below is zero, so the
        # enforced components snap exactly to c.
        lb = ub = np.asarray(kind[1], dtype=float)
    else:
        raise ValueError("Keyword `%s` not available." % (keyword,))

    x0_new = np.array(x0, dtype=float)  # np.array always copies here
    for i in np.flatnonzero(enforce_feasibility):
        # ub - lb is only formed when the bound in question is finite, so
        # no inf - inf (NaN) can arise.
        if not np.isinf(lb[i]):
            lower_bound = min(lb[i] + absolute_tolerance,
                              lb[i] + relative_tolerance * (ub[i] - lb[i]))
            x0_new[i] = max(x0_new[i], lower_bound)
        if not np.isinf(ub[i]):
            upper_bound = max(ub[i] - absolute_tolerance,
                              ub[i] - relative_tolerance * (ub[i] - lb[i]))
            x0_new[i] = min(x0_new[i], upper_bound)
    return x0_new