# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Callable
from math import exp, sqrt, tanh
import numpy as np
from nevergrad.parametrization import discretization
from nevergrad.common.decorators import Registry

# Registry for objective functions of type Callable[[np.ndarray], float]
# (see nevergrad.common.decorators.Registry).
# NOTE(review): none of the functions visible here is registered into it; presumably
# registration happens elsewhere (e.g. via a decorator) — confirm.
registry: Registry[Callable[[np.ndarray], float]] = Registry()

def onemax(x: List[int]) -> float:
    """onemax(x) is the most classical case of discrete functions, adapted to minimization.

    It is originally designed for lists of bits. It just counts the number of 1,
    and returns len(x) - number of ones (so 0 is the optimum).
    It also works in the continuous case, but then it discretizes the input domain:
    an entry w counts as 1 when int(round(w)) == 1, i.e. w in ]0.5, 1.5[ (rounding is
    half-to-even, so 0.5 -> 0 and 1.5 -> 2), and as 0 everywhere else.
    """
    return len(x) - sum(1 for w in x if int(round(w)) == 1)

def leadingones(x: List[int]) -> float:
    """leadingones is the second most classical discrete function, adapted for minimization.

    Returns len(x) - number of initial 1 (an entry counts as 1 when int(round(entry)) == 1).
    I.e.
    leadingones([0 1 1 1]) = 4,
    leadingones([1 1 1 1]) = 0,
    leadingones([1 0 0 0]) = 3.
    """
    for i, x_ in enumerate(x):
        if int(round(x_)) != 1:
            # i entries were ones before the first non-one entry
            return len(x) - i
    return 0

def jump(x: List[int]) -> float:
    """There exist variants of jump functions; this one is adapted to minimization.

    The principle of a jump function is that local descent does not succeed:
    jumps are necessary. With n = len(x), m = n // 4 and o the number of ones
    (as counted by onemax):
    - o == n (all ones) gives the optimum, -m;
    - o <= n - m gives n - m - o, which decreases as ones are added;
    - otherwise (n - m < o < n) the value is o, a deceptive region where adding
      ones makes the value worse, so the optimum can only be reached by a jump.
    """
    n = len(x)
    m = n // 4
    o = n - onemax(x)
    if o == n or o <= n - m:
        return n - m - o
    return o  # Deceptive part.

def _styblinksitang(x: np.ndarray, noise: float) -> float:
    """Classical function for testing noisy optimization (Styblinski-Tang).

    Computes 0.5 * sum(x_i ** 4 - 16 * x_i ** 2 + 5 * x_i), shifted by 39.16599 * len(x)
    so that the noiseless minimum (near x_i = -2.903534) is close to 0, plus
    ``noise`` times a standard normal sample.

    Note: the normal sample is drawn from the global numpy RNG even when ``noise`` is 0.
    """
    x2 = x ** 2
    val = x2.dot(x2) + np.sum(5 * x - 16 * x2)
    # shift by the per-dimension optimal value so that the noiseless optimum is close to 0
    # (this is a minimization problem)
    return float(39.16599 * len(x) + 0.5 * val + noise * np.random.normal(size=val.shape))

class DelayedSphere:
    """Sphere function paired with a pseudo-time that grows as x[0] approaches 0.

    Calling an instance returns sum(x ** 2). ``compute_pseudotime`` returns
    |1 / x[0]| / 1000 for the evaluated point, or 0 when x[0] == 0.
    """

    def __call__(self, x: np.ndarray) -> float:
        squared = x ** 2
        return float(np.sum(squared))

    def compute_pseudotime(self, input_parameter: Any, value: float) -> float:  # pylint: disable=unused-argument
        # the evaluated array is input_parameter[0][0] (presumably the first positional
        # argument of an (args, kwargs) pair — TODO confirm against caller)
        point = input_parameter[0][0]
        first = point[0]
        if first == 0.0:
            return 0.0
        return float(abs(1.0 / first) / 1000.0)


def sphere(x: np.ndarray) -> float:
    """Sum of squared coordinates, minimal (0) at the origin.

    The most classical continuous optimization testbed:
    if you do not solve that one then you have a bug.
    Expects a one-dimensional array (asserted).
    """
    assert x.ndim == 1
    return float(np.dot(x, x))

def sphere1(x: np.ndarray) -> float:
    """Sphere function translated so that its optimum is at (1, ..., 1)."""
    shifted = x - 1.0
    return sphere(shifted)

def sphere2(x: np.ndarray) -> float:
    """Sphere function translated so that its optimum is at (2, ..., 2)."""
    shifted = x - 2.0
    return sphere(shifted)

def sphere4(x: np.ndarray) -> float:
    """Sphere function translated so that its optimum is at (4, ..., 4)."""
    shifted = x - 4.0
    return sphere(shifted)

def maxdeceptive(x: np.ndarray) -> float:
    """Maximum over coordinates of 3 * x_i ** 2 - 2 / (3 ** ((x_i - 2) ** 2) + 0.1).

    The second term is most negative at x_i = 2 and vanishes far from it.
    The input is converted to floats: with integer arrays, numpy's integer power
    3 ** ((x_i - 2) ** 2) silently wraps around and corrupts the value.
    """
    x = np.asarray(x, dtype=float)
    dec = 3 * x ** 2 - (2 / (3 ** ((x - 2) ** 2) + 0.1))
    return float(np.max(dec))

def sumdeceptive(x: np.ndarray) -> float:
    """Sum over coordinates of 3 * x_i ** 2 - 2 / (3 ** ((x_i - 2) ** 2) + 0.1).

    The second term is most negative at x_i = 2 and vanishes far from it.
    The input is converted to floats: with integer arrays, numpy's integer power
    3 ** ((x_i - 2) ** 2) silently wraps around and corrupts the value.
    """
    x = np.asarray(x, dtype=float)
    dec = 3 * x ** 2 - (2 / (3 ** ((x - 2) ** 2) + 0.1))
    return float(np.sum(dec))

def altcigar(x: np.ndarray) -> float:
    """Similar to cigar, but variables in inverse order.

    The last coordinate has weight 1 and all the others weight 1e6, e.g. for
    pointing out algorithms that are not invariant to the order of variables."""
    last = float(x[-1])
    others = x[:-1]
    return last ** 2 + 1000000.0 * sphere(others)

def discus(x: np.ndarray) -> float:
    """Only one variable is very penalized: the first coordinate has weight 1e6, the others 1."""
    rest = sphere(x[1:])
    return rest + 1000000.0 * float(x[0]) ** 2

def cigar(x: np.ndarray) -> float:
    """Classical example of ill conditioned function.

    The first coordinate has weight 1 and all the others weight 1e6.
    The other classical example is ellipsoid.
    """
    return float(x[0]) ** 2 + 1000000.0 * sphere(x[1:])

def bentcigar(x: np.ndarray) -> float:
    """Classical example of ill conditioned function, but bent.

    Each positive coordinate x[i] is raised to the power
    1 + 0.5 * sqrt(x[i]) * (i - 1) / (len(x) - 1); non-positive coordinates are kept.
    The result is then evaluated as a cigar: weight 1 on the first coordinate, 1e6 on the others.
    In dimension 1 that exponent would divide by zero, so no bending is applied there
    and the value is x[0] ** 2.
    """
    dim = len(x)
    # NOTE(review): i is 0-based here whereas the usual (BBOB) definition of this asymmetric
    # transformation is 1-based, so exponents are shifted by one index; kept unchanged so that
    # benchmark values stay comparable — confirm whether this is intended.
    y = np.asarray(
        [
            x[i] ** (1 + .5 * np.sqrt(x[i]) * (i - 1) / (dim - 1)) if x[i] > 0. and dim > 1 else x[i]
            for i in range(dim)
        ]
    )
    return float(y[0]) ** 2 + 1000000.0 * sphere(y[1:])

def multipeak(x: np.ndarray) -> float:
    """Inspired by M. Gallagher's Gaussian peaks function.

    Returns the minimum, over 101 peaks indexed by a = 0..100 and centred at
    c_a[i] = cos(a + sqrt(i)), of a / 101 + exp(||x - c_a|| ** 2), capped at 10000.
    """
    # sqrt(i) does not depend on the peak index, compute it once
    roots = [np.sqrt(i) for i in range(len(x))]
    best = 10000.
    for peak in range(101):
        center = np.asarray([np.cos(peak + r) for r in roots])
        candidate = peak / 101. + np.exp(sphere(x - center))
        best = min(best, candidate)
    return best

def altellipsoid(y: np.ndarray) -> float:
    """Similar to Ellipsoid, but variables in inverse order.

    Evaluates ellipsoid on the reversed vector, e.g. for pointing out algorithms
    that are not invariant to the order of variables."""
    reversed_y = y[::-1]
    return ellipsoid(reversed_y)

def step(s: float) -> float:
    """Quantize a non-negative value to an integer power of e.

    Returns exp(int(log(s))). ``int`` truncates toward zero, so values above e are
    rounded down to a power of e, values in ]1/e, e[ map to 1, and values in ]0, 1/e]
    are rounded up to a power of e.
    For s == 0 this returns 0.0 (the limit as s -> 0) instead of failing on int(-inf);
    negative inputs are not supported (log gives NaN).
    """
    if s == 0:
        return 0.0
    return float(np.exp(int(np.log(s))))

def stepellipsoid(x: np.ndarray) -> float:
    """Classical example of ill conditioned function.

    But we add a 'step': the ellipsoid value is quantized with ``step``, so the
    gradient is zero almost everywhere.
    Compared to some existing testbeds, we decided to have infinitely many steps.
    """
    dim = x.size
    # weights grow geometrically from 1 (first coordinate) to 1e6 (last coordinate)
    weights = 10 ** np.linspace(0, 6, dim)
    return step(weights.dot(x ** 2))

def ellipsoid(x: np.ndarray) -> float:
    """Classical example of ill conditioned function.

    Weighted sum of squares whose weights grow geometrically from 1 (first coordinate)
    to 1e6 (last coordinate).
    The other classical example is cigar.
    """
    dim = x.size
    weights = 10 ** np.linspace(0, 6, dim)
    return float(weights.dot(x ** 2))

def rastrigin(x: np.ndarray) -> float:
    """Classical multimodal function: 10 * (d - sum(cos(2 pi x_i))) + sum(x_i ** 2)."""
    angles = 2 * np.pi * x
    cosine_total = float(np.sum(np.cos(angles)))
    penalty = 10 * (len(x) - cosine_total)
    return float(penalty + sphere(x))

def bucherastrigin(x: np.ndarray) -> float:
    """Classical multimodal function. No box-constraint penalization here.

    Rastrigin evaluated on a rescaled vector s with
    s[i] = x[i] * (10 if x[i] > 0 and i is odd else 1) * 10 ** ((i - 1) / (2 * (len(x) - 1))).
    In dimension 1 the scaling exponent would divide by zero (ZeroDivisionError), so the
    scale factor is taken as 1 there.
    """
    dim = len(x)
    # NOTE(review): i is 0-based here whereas the usual (BBOB) Bueche-Rastrigin definition is
    # 1-based (both in the exponent and in the parity test); kept unchanged so that benchmark
    # values stay comparable — confirm whether this is intended.
    s = np.asarray(
        [
            x[i]
            * (10 if x[i] > 0. and i % 2 else 1)
            * (10 ** ((i - 1) / (2 * (dim - 1))) if dim > 1 else 1.0)
            for i in range(dim)
        ]
    )
    cosi = float(np.sum(np.cos(2 * np.pi * s)))
    return float(10 * (dim - cosi) + sphere(s))

def doublelinearslope(x: np.ndarray) -> float:
    """Absolute value of the sum of the coordinates.

    We decided to use two linear slopes rather than having a constraint artificially added for
    not having the optimum at infinity: the minimum 0 is reached on the hyperplane sum(x) = 0."""
    total = np.sum(x)
    return float(np.abs(total))

def stepdoublelinearslope(x: np.ndarray) -> float:
    """Double linear slope |sum(x)|, quantized with ``step``."""
    magnitude = np.abs(np.sum(x))
    return step(magnitude)

def hm(x: np.ndarray) -> float:
    """New multimodal function (proposed for Nevergrad).

    Computes sum(x_i ** 2 * (1.1 + cos(1 / x_i))). The cosine factor stays within
    [0.1, 2.1], so each term tends to 0 as x_i -> 0; zero coordinates therefore contribute
    exactly 0 instead of NaN (0 * cos(inf)), and the optimum at the origin evaluates to 0.
    """
    x = np.asarray(x, dtype=float)
    # replace zeros by 1 in the divisor only: the matching x ** 2 factor zeroes those terms anyway
    divisor = np.where(x == 0.0, 1.0, x)
    return float((x ** 2).dot(1.1 + np.cos(1.0 / divisor)))

def rosenbrock(x: np.ndarray) -> float:
    """Classical Rosenbrock function, minimal (0) at (1, ..., 1).

    Sum over consecutive pairs of 100 * (x_i ** 2 - x_{i+1}) ** 2 + (x_i - 1) ** 2."""
    head, tail = x[:-1], x[1:]
    curvature = head ** 2 - tail
    offset = head - 1
    return float(100 * curvature.dot(curvature) + offset.dot(offset))

def ackley(x: np.ndarray) -> float:
    """Classical multimodal Ackley function, with its global minimum at the origin."""
    dim = x.size
    radial = -20.0 * exp(-0.2 * sqrt(sphere(x) / dim))
    oscillation = exp(np.sum(np.cos(2 * np.pi * x)) / dim)
    return radial - oscillation + 20 + exp(1)

def schwefel_1_2(x: np.ndarray) -> float:
    """Schwefel 1.2 function: sum of the squares of the prefix sums of x."""
    return sphere(np.cumsum(x))

def griewank(x: np.ndarray) -> float:
    """Multimodal function, often used in Bayesian optimization.

    1 + sum(x_i ** 2) / 4000 - prod(cos(x_i / sqrt(i))), with 1-based i."""
    quadratic = float(sphere(x)) / 4000.0
    divisors = np.sqrt(1 + np.arange(len(x)))
    product = float(np.prod(np.cos(x / divisors)))
    return 1 + quadratic - product

def deceptiveillcond(x: np.ndarray) -> float:
    """An extremely ill conditioned function. Most algorithms fail on this.

    The condition number increases to infinity as we get closer to the optimum.
    Only the first two coordinates are used: the value is
    max(|arctan(x_1 / x_0)|, ||(x_0, x_1)||, 1 if x_0 > 0 else 0), and +inf when x_0 == 0.
    """
    assert len(x) >= 2
    if x[0] == 0.0:
        return float("inf")
    return float(
        max(
            np.abs(np.arctan(x[1] / x[0])),
            np.sqrt(x[0] ** 2.0 + x[1] ** 2.0),
            1.0 if x[0] > 0 else 0.0,
        )
    )

def deceptivepath(x: np.ndarray) -> float:
    """A function which needs following a long path. Most algorithms fail on this.

    The path becomes thinner as we get closer to the optimum. Only the first two
    coordinates are used: with r their Euclidean norm and angle = arctan(x_0 / x_1)
    (pi / 2 when x_1 == 0), the value is 1 when |cos(1 / r) - angle| > 0.1 and r
    otherwise (0 at the origin)."""
    assert len(x) >= 2
    radius = np.sqrt(x[0] ** 2 + x[1] ** 2)
    if radius == 0.0:
        return 0.0
    angle = np.pi / 2.0 if x[1] == 0.0 else np.arctan(x[0] / x[1])
    inverse = 0.0 if not radius > 0.0 else 1.0 / radius
    off_path = np.abs(np.cos(inverse) - angle) > 0.1
    return 1.0 if off_path else float(radius)

def deceptivemultimodal(x: np.ndarray) -> float:
    """Infinitely many local optima, as we get closer to the optimum.

    Only the first two coordinates are used: with r their Euclidean norm and
    angle = arctan(x_0 / x_1) (pi / 2 when x_1 == 0), the value is 1 when
    |cos(int(1 / r)) - angle| > 0.1 and r otherwise (0 at the origin)."""
    assert len(x) >= 2
    radius = np.sqrt(x[0] ** 2 + x[1] ** 2)
    if radius == 0.0:
        return 0.0
    if x[1] == 0.0:
        angle = np.pi / 2.0
    else:
        angle = np.arctan(x[0] / x[1])
    if radius > 0.0:
        # truncating 1 / r is what creates the many local optima
        inverse = int(1.0 / radius)
    else:
        inverse = 0.0
    if not np.abs(np.cos(inverse) - angle) > 0.1:
        return float(radius)
    return 1.0

def lunacek(x: np.ndarray) -> float:
    """Multimodal function.

    Based on https://www.cs.unm.edu/~neal.holts/dga/benchmarkFunction/lunacek.html.
    The value is min(sum((x_i - mu1) ** 2), d + sum((x_i - mu2) ** 2))
    + 10 * sum(1 - cos(2 pi (x_i - mu1))), with mu1 = 2.5 and mu2 derived from mu1
    and the dimension d."""
    dim = len(x)
    s = 1.0 - (1.0 / (2.0 * np.sqrt(dim + 20.0) - 8.2))
    mu1 = 2.5
    mu2 = -np.sqrt(abs((mu1 ** 2 - 1.0) / s))
    near_first = 0.0
    near_second = 0.0
    ripples = 0.0
    # accumulate sequentially (not with np.sum) to keep the exact same rounding
    for xi in x:
        near_first += (xi - mu1) ** 2
        near_second += (xi - mu2) ** 2
        ripples += 1.0 - np.cos(2 * np.pi * (xi - mu1))
    return min(near_first, 1.0 * dim + near_second) + 10 * ripples

# The following functions use discretization and should not be used with translation/rotation.

def hardonemax(y: np.ndarray) -> float:
    """Onemax, with a discretization in 2 by threshold 0 (>0 or <0)."""
    bits = discretization.threshold_discretization(y)
    return onemax(bits)

def hardjump(y: np.ndarray) -> float:
    """Jump, with a discretization in 2 by threshold 0 (>0 or <0)."""
    bits = discretization.threshold_discretization(y)
    return jump(bits)

def hardleadingones(y: np.ndarray) -> float:
    """Leading ones, with a discretization in 2 by threshold 0 (>0 or <0)."""
    bits = discretization.threshold_discretization(y)
    return leadingones(bits)

def hardonemax5(y: np.ndarray) -> float:
    """Onemax, with a discretization by 5 with 4 thresholds (quantiles of Gaussian)."""
    levels = discretization.threshold_discretization(y, 5)
    return onemax(levels)

def hardjump5(y: np.ndarray) -> float:
    """Jump, with a discretization by 5 with 4 thresholds (quantiles of Gaussian)."""
    levels = discretization.threshold_discretization(y, 5)
    return jump(levels)

def hardleadingones5(y: np.ndarray) -> float:
    """Leadingones, with a discretization by 5 with 4 thresholds (quantiles of Gaussian)."""
    levels = discretization.threshold_discretization(y, 5)
    return leadingones(levels)

def onemax5(y: np.ndarray) -> float:
    """Softmax discretization of onemax with 5 possible values.

    y is reshaped into rows of 5 values (so its size must be a multiple of 5) and encoded
    by ``discretization.Encoder`` using the global ``np.random`` as random state; presumably
    each row yields one discrete value, so this multiplies the dimension by 5."""
    encoder = discretization.Encoder(y.reshape(-1, 5), np.random)
    return onemax(encoder.encode().tolist())

def jump5(y: np.ndarray) -> float:
    """Softmax discretization of jump with 5 possible values.

    y is reshaped into rows of 5 values (so its size must be a multiple of 5) and encoded
    by ``discretization.Encoder`` using the global ``np.random`` as random state; presumably
    each row yields one discrete value, so this multiplies the dimension by 5."""
    encoder = discretization.Encoder(y.reshape(-1, 5), np.random)
    return jump(encoder.encode().tolist())

def leadingones5(y: np.ndarray) -> float:
    """Softmax discretization of leadingones with 5 possible values.

    y is reshaped into rows of 5 values (so its size must be a multiple of 5) and encoded
    by ``discretization.Encoder`` using the global ``np.random`` as random state; presumably
    each row yields one discrete value, so this multiplies the dimension by 5."""
    encoder = discretization.Encoder(y.reshape(-1, 5), np.random)
    return leadingones(encoder.encode().tolist())

def genzcornerpeak(y: np.ndarray) -> float:
    """One of the Genz functions, originally used in integration, here used as an optimization test.

    Computes (1 + mean(tanh(y))) ** (-len(y) - 1), or +inf when the base is exactly 0."""
    base = float(1 + np.mean(np.tanh(y)))
    exponent = -len(y) - 1
    return float("inf") if base == 0 else base ** exponent

def minusgenzcornerpeak(y: np.ndarray) -> float:
    """Negated Genz corner peak: minimizing it amounts to maximizing genzcornerpeak."""
    value = genzcornerpeak(y)
    return -value

def genzgaussianpeakintegral(x: np.ndarray) -> float:
    """One of the Genz functions, originally used in integration, here used as an optimization test.

    Gaussian peak exp(-||x|| ** 2 / 4), maximal (1) at the origin."""
    squared_norm = sphere(x)
    return exp(-squared_norm / 4.0)

def minusgenzgaussianpeakintegral(x: np.ndarray) -> float:
    """Negated Genz Gaussian peak: minimizing it amounts to maximizing genzgaussianpeakintegral."""
    value = genzgaussianpeakintegral(x)
    return -value

def slope(x: np.ndarray) -> float:
    """Linear function: the plain (built-in) sum of the coordinates, unbounded below."""
    total = sum(x)
    return total

def linear(x: np.ndarray) -> float:
    """Depends only on the first coordinate, through tanh(x[0]) (monotone, bounded by 1 in absolute value)."""
    first = x[0]
    return tanh(first)

def st0(x: np.ndarray) -> float:
    """Styblinski-Tang function without noise."""
    return _styblinksitang(x, noise=0)

def st1(x: np.ndarray) -> float:
    """Styblinski-Tang function with noise of standard deviation 1."""
    return _styblinksitang(x, noise=1)

def st10(x: np.ndarray) -> float:
    """Styblinski-Tang function with noise of standard deviation 10."""
    return _styblinksitang(x, noise=10)

def st100(x: np.ndarray) -> float:
    """Styblinski-Tang function with noise of standard deviation 100."""
    return _styblinksitang(x, noise=100)