import numpy as np
import scipy.sparse.linalg as ssl
import util
from collections import namedtuple
import nn


def adagrad(grad_func, x0, learning_rate, eps=1e-8):
    '''
    Generator implementing the Adagrad update rule.

    Yields the current iterate after every update. The iterate is updated
    in place, so copy it if intermediate values need to be kept.
    '''
    x = x0.copy()
    cache = np.zeros_like(x)
    while True:
        g = grad_func(x)
        cache += g**2
        x -= learning_rate * g / np.sqrt(cache + eps)
        yield x
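

# A minimal usage sketch of the adagrad generator on a toy quadratic; the
# objective, starting point, and iteration count below are illustrative
# assumptions, not part of this module.
def _adagrad_quadratic_demo(num_iters=500):
    target = np.array([1., -2., 3.])
    grad_func = lambda x: 2.*(x - target)  # gradient of ||x - target||^2
    opt = adagrad(grad_func, x0=np.zeros(3), learning_rate=.5)
    for _ in range(num_iters):
        x = next(opt)
    return x  # should be close to target after enough iterations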


def btlinesearch(f, x0, fx0, g, dx, accept_ratio, shrink_factor, max_steps, verbose=False):
    '''
    Backtracking line search: find a step size t such that f(x0 + t*dx)
    improves on f(x0) by at least accept_ratio times the improvement
    predicted by linearizing f at x0.

    Args:
        f: the function
        x0: starting point for the search
        fx0: the value f(x0). Will be computed here if set to None.
        g: search direction, typically the gradient of f at x0
        dx: the largest possible step to take
        accept_ratio: fraction of the linearized improvement required to accept a step
        shrink_factor: factor by which to shrink the step size every iteration
        max_steps: maximum number of shrinking iterations
        verbose: if True, print progress at every iteration
    '''
    if fx0 is None: fx0 = f(x0)
    t = 1.
    m = g.dot(dx)
    if accept_ratio != 0 and m > 0: util.warn('WARNING: %.10f not <= 0; dx is not a descent direction' % m)
    num_steps = 0
    while num_steps < max_steps:
        true_imp = f(x0 + t*dx) - fx0
        lin_imp = t*m
        if verbose: print('btlinesearch: true_imp=%g, lin_imp=%g, accept_ratio=%g' % (true_imp, lin_imp, accept_ratio))
        if true_imp <= accept_ratio * lin_imp:
            break
        t *= shrink_factor
        num_steps += 1
    return x0 + t*dx, num_steps
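

# A minimal usage sketch of btlinesearch on a toy quadratic; the function,
# starting point, and overshooting step below are illustrative assumptions.
def _btlinesearch_demo():
    f = lambda x: float(x.dot(x))    # f(x) = ||x||^2, minimized at the origin
    x0 = np.array([2., 2.])
    g = 2.*x0                        # gradient of f at x0
    dx = -g                          # full step; deliberately overshoots the minimum
    xnew, num_steps = btlinesearch(f, x0, f(x0), g, dx,
                                   accept_ratio=.1, shrink_factor=.5, max_steps=10)
    assert f(xnew) < f(x0)           # the accepted step decreases f
    return xnew, num_steps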


def numdiff_hvp(v, grad_func, x0, grad0=None, finitediff_delta=1e-4):
    '''
    Approximate Hessian-vector product.

    Uses a 1-dimensional finite difference approximation for the
    directional derivative of the gradient function.

    Args:
        v: the vector to left-multiply by the Hessian
        grad_func: gradient function
        x0: point at which to evaluate the Hessian
        grad0: should equal grad_func(x0), or None to compute it here.
        finitediff_delta: step size for finite difference
    '''
    assert v.shape == x0.shape
    if np.allclose(v, 0): return np.zeros_like(v)
    eps = finitediff_delta / np.linalg.norm(v)
    dx = eps * v
    if grad0 is None: grad0 = grad_func(x0)
    grad1 = grad_func(x0+dx)
    return (grad1 - grad0) / eps
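

# A minimal sketch checking numdiff_hvp against the exact Hessian-vector
# product of a quadratic form; the matrix A, x0, and v below are illustrative
# assumptions.
def _numdiff_hvp_demo():
    rng = np.random.RandomState(0)
    A = rng.randn(4, 4)
    A = A.dot(A.T)                       # symmetric Hessian of f(x) = .5 x^T A x
    grad_func = lambda x: A.dot(x)
    x0, v = rng.randn(4), rng.randn(4)
    approx = numdiff_hvp(v, grad_func, x0)
    assert np.allclose(approx, A.dot(v), atol=1e-4)
    return approx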


def ngstep(x0, obj0, objgrad0, obj_and_kl_func, hvpx0_func, max_kl, damping, max_cg_iter, enable_bt):
    '''
    Natural gradient step using Hessian-vector products.

    Args:
        x0: current point
        obj0: objective value at x0
        objgrad0: gradient of the objective at x0
        obj_and_kl_func: function mapping a point x to the objective and KL values
        hvpx0_func: function mapping a vector v to the KL Hessian-vector product H(x0)v
        max_kl: max KL divergence limit; used to scale the step, and enforced (as a hard wall at 2*max_kl) by the backtracking line search
        damping: multiple of I to mix into the Hessian for Hessian-vector products
        max_cg_iter: max conjugate gradient iterations for solving for the natural gradient step
        enable_bt: whether to run the backtracking line search on the objective
    '''

    assert x0.ndim == 1 and x0.shape == objgrad0.shape

    # Solve for step direction
    damped_hvp_func = lambda v: hvpx0_func(v) + damping*v
    hvpop = ssl.LinearOperator(shape=(x0.shape[0], x0.shape[0]), matvec=damped_hvp_func)
    step, _ = ssl.cg(hvpop, -objgrad0, maxiter=max_cg_iter)
    fullstep = step / np.sqrt(.5 * step.dot(damped_hvp_func(step)) / max_kl + 1e-8)

    # Line search on objective with a hard KL wall
    if not enable_bt:
        return x0+fullstep, 0

    def barrierobj(p):
        obj, kl = obj_and_kl_func(p)
        return np.inf if kl > 2*max_kl else obj
    xnew, num_bt_steps = btlinesearch(
        f=barrierobj,
        x0=x0,
        fx0=obj0,
        g=objgrad0,
        dx=fullstep,
        accept_ratio=.1, shrink_factor=.5, max_steps=10)
    return xnew, num_bt_steps
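

# A minimal sketch of calling ngstep on a toy problem: a quadratic objective
# with a quadratic KL surrogate whose Hessian is the identity. The matrices
# and constants below are illustrative assumptions, not part of this module.
def _ngstep_quadratic_demo(max_kl=.01, damping=.1):
    rng = np.random.RandomState(0)
    dim = 5
    A = rng.randn(dim, dim)
    A = A.dot(A.T) + np.eye(dim)                           # SPD objective curvature
    b = rng.randn(dim)
    F = np.eye(dim)                                        # stand-in for the KL Hessian (Fisher)
    x0 = rng.randn(dim)

    f = lambda x: .5*x.dot(A).dot(x) - b.dot(x)
    kl = lambda x: .5*(x - x0).dot(F).dot(x - x0)
    xnew, num_bt_steps = ngstep(
        x0=x0,
        obj0=f(x0),
        objgrad0=A.dot(x0) - b,
        obj_and_kl_func=lambda p: (f(p), kl(p)),
        hvpx0_func=lambda v: F.dot(v),
        max_kl=max_kl,
        damping=damping,
        max_cg_iter=10,
        enable_bt=True)
    assert f(xnew) < f(x0) and kl(xnew) <= 2*max_kl        # objective decreases inside the KL wall
    return xnew, num_bt_steps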


def subsample_feed(feed, frac):
    assert isinstance(feed, tuple) and len(feed) >= 1
    assert isinstance(frac, float) and 0. < frac <= 1.
    l = feed[0].shape[0]
    assert all(a.shape[0] == l for a in feed), 'All feed entries must have the same length'
    subsamp_inds = np.random.choice(l, size=int(frac*l))  # note: sampled with replacement
    return tuple(a[subsamp_inds,...] for a in feed)
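

# A minimal usage sketch of subsample_feed on a pair of row-aligned arrays;
# the arrays and fraction below are illustrative assumptions. The same row
# indices are applied to every entry, so corresponding rows stay aligned.
def _subsample_feed_demo():
    obs = np.arange(10.).reshape(10, 1)
    acts = np.arange(10., 20.).reshape(10, 1)
    sub_obs, sub_acts = subsample_feed((obs, acts), frac=.5)
    assert sub_obs.shape[0] == sub_acts.shape[0] == 5
    assert np.allclose(sub_acts, sub_obs + 10.)  # row alignment preserved
    return sub_obs, sub_acts
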

NGStepInfo = namedtuple('NGStepInfo', 'obj0, kl0, obj1, kl1, gnorm, bt')
def make_ngstep_func(model, compute_obj_kl, compute_obj_kl_with_grad, compute_kl_hvp):
    '''
    Makes a wrapper around ngstep for classes that implement nn.Model.
    Subsamples the input feed for fast Hessian-vector products, and returns
    an NGStepInfo with the objective and KL values before and after the step,
    the gradient norm, and the number of backtracking line search steps taken.
    '''
    assert isinstance(model, nn.Model)

    def wrapper(feed, max_kl, damping, subsample_hvp_frac=.1, grad_stop_tol=1e-6, max_cg_iter=10, enable_bt=True):
        assert isinstance(feed, tuple)

        params0 = model.get_params()
        obj0, kl0, objgrad0 = compute_obj_kl_with_grad(*feed)
        gnorm = util.maxnorm(objgrad0)
        assert np.allclose(kl0, 0), 'Initial KL divergence is %.7f, but should be 0' % (kl0,)
        # Terminate early if gradient is too small
        if gnorm < grad_stop_tol:
            return NGStepInfo(obj0, kl0, obj0, kl0, gnorm, 0)

        # Data subsampling for Hessian-vector products
        subsamp_feed = feed if subsample_hvp_frac is None else subsample_feed(feed, subsample_hvp_frac)
        def hvpx0_func(v):
            with model.try_params(params0):
                hvp_args = subsamp_feed + (v,)
                return compute_kl_hvp(*hvp_args)
        # Objective for line search
        def obj_and_kl_func(p):
            with model.try_params(p):
                obj, kl = compute_obj_kl(*feed)
            return -obj, kl
        params1, num_bt_steps = ngstep(
            x0=params0,
            obj0=-obj0,
            objgrad0=-objgrad0,
            obj_and_kl_func=obj_and_kl_func,
            hvpx0_func=hvpx0_func,
            max_kl=max_kl,
            damping=damping,
            max_cg_iter=max_cg_iter,
            enable_bt=enable_bt)
        model.set_params(params1)
        obj1, kl1 = compute_obj_kl(*feed)
        return NGStepInfo(obj0, kl0, obj1, kl1, gnorm, num_bt_steps)

    return wrapper
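

# A hedged sketch of the intended calling convention. The policy model and the
# compute_* functions named below are assumptions about the surrounding
# codebase (they must follow the interfaces used above), not definitions made
# here:
#
#   take_ngstep = make_ngstep_func(policy, compute_obj_kl,
#                                  compute_obj_kl_with_grad, compute_kl_hvp)
#   info = take_ngstep(feed=(obs, actions, advantages), max_kl=.01, damping=.1)
#   # info.obj0/info.obj1: objective before/after the step
#   # info.kl1: KL divergence after the step; info.bt: backtracking steps taken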