"""
scalarnl.py:  Estimation methods for scalar nonlinear functions
"""
from __future__ import division

import numpy as np

# Import other subpackages in vampyre
import vampyre.common as common

# Import individual classes and methods from the current sub-package
from vampyre.estim.base import BaseEst

class ScalarNLEst(BaseEst):
    """
    Base class for an estimator with a general scalar nonlinear penalty
    
    This estimator corresponds to a general nonlinear penalty :math:`f(z)`.
    The penalty function is defined in the derived class by implementing
    the method :code:`fnl` (see the sketch after the parameter list below).
    Currently, only MAP estimation is implemented, which is performed via a
    damped Newton's method.
    
    :param shape:  Shape of :math:`z`.
    :param var_axes:  The axes on which the input variance is repeated.
    :param step_init:  Initial gradient descent step-size
    :param step_max:  Max gradient descent step-size
    :param step_min:  Min gradient descent step-size
    :param max_it:  Max number of gradient-descent iterations per call to the 
        estimation method
    :param zinit:  Initial estimate for :math:`z`
    :param is_complex:  Flag indicating if :math:`z` is complex.
    :param gtol:  gradient norm tolerance.
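    
    A minimal sketch of a derived class (the class name and quadratic
    penalty below are purely illustrative, not part of the package)::
    
        class QuadEst(ScalarNLEst):
            # Quadratic penalty f(z) = 0.5*z**2
            def fnl(self, z):
                f = 0.5*(z**2)            # penalty value
                fgrad = z                 # gradient
                fhess = np.ones(z.shape)  # second derivative
                return f, fgrad, fhess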
        
    """    
    def __init__(self,shape,var_axes=(0,),max_it=10, step_init=1,\
        dtype=np.float64,name=None,\
        step_max=1,step_min=1e-8,zinit=None,is_complex=False,gtol=1e-3):
        
        BaseEst.__init__(self,shape=shape,var_axes=var_axes,dtype=dtype,\
            name=name, type_name='ScalarNL', nvars=1, cost_avail=True)

        # Save parameters
        self.max_it = max_it
        self.cost_avail = True
        self.step_last = step_init
        self.step_max = step_max
        self.step_min = step_min
        self.is_complex = is_complex
        self.gtol = gtol
        self.min_it = 1
        
        # Set default value that the Hessian is available
        self.hess_avail = True
                
        # Set initial point
        if (zinit is None):
            self.zlast = self.proj(np.zeros(self.shape))
        else:
            self.zlast = zinit
            
    def proj(self,z):
        """
        Projects a value :math:`z` onto the feasible set.
        The default implementation performs no projection.
        """
        return z
        
    def fnl(self,z):
        """
        Penalty function.  This must be implemented in the derived class.
        
        :param z:  The input to the penalty function.
        
        :returns:  :code:`f,fgrad,[fhess]`, the function value, its gradient
           and second derivative.  The second derivative need only be
           returned when :code:`self.hess_avail == True`.
        """
        raise NotImplementedError()
            
    def fnl_aug(self,z,r,rvar):
        """
        Augmented nonlinear function and its derivatives.
        
        Given the penalty function :math:`f(z)`, the function returns
        the value, gradient and (when :code:`self.hess_avail == True`)
        second derivative of the augmented function,
        
        :math:`f_{aug}(z,r) = f(z) + \\frac{1}{\\tau_r}|r-z|^2,`
        
        where the quadratic term carries an additional factor of 1/2 when
        :math:`z` is real-valued.
        """
        
        # Evaluate function
        if self.hess_avail:        
            f0, fgrad0, fhess0 = self.fnl(z)
        else:
            f0, fgrad0 = self.fnl(z)
                    
        # Add the augmenting term
        rvar_rep = common.repeat_axes(rvar,self.shape,self.var_axes,rep=False)
        aug = np.abs(z-r)**2/rvar_rep
        aug_grad = 2*np.conj(z-r)/rvar_rep
        aug_hess = 2/rvar_rep
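        # For real-valued z, the Gaussian augmenting term carries a factor of
        # 1/2, so the term and its derivatives are halved.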
        if not self.is_complex:
            aug /= 2
            aug_grad /= 2
            aug_hess /= 2
        faug = np.sum(f0 + aug)
        faug_grad = fgrad0 + aug_grad
        if self.hess_avail:
            faug_hess = fhess0 + aug_hess                         
            return faug, faug_grad, faug_hess
        else:
            return faug, faug_grad            
        
    def est_init(self, return_cost=False,ind_out=None,\
        avg_var_cost=True):
        """
        Initial estimator.
        
        See the base class :class:`vampyre.estim.base.BaseEst` for
        a complete description.
        
        The default implementation calls :code:`est` with the initial mean
        and variance (which is typically large).
        
        :param boolean return_cost:  Flag indicating if :code:`cost` is 
            to be returned
        :returns: :code:`zmean, zvar, [cost]`, the prior mean, variance
            and optional cost
        """        
        return self.est(self.rinit,self.rvar_init,return_cost,ind_out,\
            avg_var_cost)

    def est(self,r,rvar,return_cost=False,ind_out=None,\
        avg_var_cost=True):
        """
        Estimation function
        
        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.BaseEst`
                
        :param r: Proximal mean
        :param rvar: Proximal variance
        :param Boolean return_cost:  Flag indicating if :code:`cost` is 
            to be returned
        
        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        """        
        
        # Check parameters
        if (ind_out != [0]) and (ind_out is not None):
            raise ValueError("ind_out must be either [0] or None")
        if not avg_var_cost:
            raise ValueError(
                "disabling variance averaging is not supported for ScalarNLEst")

        # Check that the Hessian is available
        if not self.hess_avail:
            raise Exception(
                "The second derivative is currently required.  It is needed "
                "for the variance computation in the current implementation.")
        
        
        # Get point from previous run
        z = self.zlast
        step = self.step_last
        
        # Compute initial function and step direction
        if self.hess_avail:
            f, fgrad, fhess = self.fnl_aug(z,r,rvar)
            g = fgrad / fhess
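            # The inverse Hessian of the augmented objective serves as a
            # quadratic (Laplace-style) approximation of the posterior variance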
            zhatvar = 1 / fhess
        else:
            f, fgrad = self.fnl_aug(z,r,rvar)
            g = fgrad
            
        a = 0.5
        done = False
        it = 0
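        # Damped Newton iteration with an adaptive step size:  try a step
        # along the Newton direction g, accept it only if it yields a
        # sufficient decrease (Armijo-style test with parameter a), and
        # grow or shrink the step accordingly.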
        while not done:        
            # Try test point
            z1 = z - step*g
            z1 = self.proj(z1)
            if self.hess_avail:
                f1, fgrad1, fhess1 = self.fnl_aug(z1,r,rvar)                
                g1 = fgrad1/fhess1
            else:
                f1, fgrad1 = self.fnl_aug(z1,r,rvar)
                g1 = fgrad1
            
            
            # Compute expected decrease
            dfest = np.sum(np.conj(fgrad)*(z1-z))
            
            # Accept or reject point
            if (f1-f < a*dfest) and (dfest < 0):
                z = z1
                f = f1
                fgrad = fgrad1
                g = g1
                step = 2*step
                step = np.minimum(self.step_max, step)
                zhatvar = 1/fhess1
                
            else:
                step = 0.5*step
                step = np.maximum(self.step_min, step)
                
            # Check termination
            it += 1
            gnorm = np.mean(np.abs(g)**2)
            done = (it >= self.max_it) or (gnorm < self.gtol)
            done = done and (it >= self.min_it)
                            
        # Save results
        self.zlast = z
        self.step_last = step
        self.nit_last = it
                
        # Average the variance
        zhatvar = np.mean(zhatvar, axis=self.var_axes)
        if self.is_complex:
            zhatvar /= 2
                    
        if return_cost:
            return z, zhatvar, f
        else:
            return z, zhatvar
            
class LogisticEst(ScalarNLEst):
    def __init__(self,y,var_axes=(0,),max_it=100,gtol=1e-6,\
        rinit=None,rvar_init=10,name=None,dtype=np.float64):
        """
        Lotistic estimator with binary class label.
        
        The penalty function is given by,
        
        :math:`f(z,y) = z - zy + \\log(1 + \\exp(-z))`
        
        To avoid overflow, this is implemented as
        
        :math:`f(z,y) = \\max(z,0) - zy + \\log(1 + \\exp(-|z|))`
                    
        :param y:  Binary class labels 0 or 1.
        :param max_it:  Maximum number of Newton iterations
        :param gtol:  Stopping tolerance for optimization
        :param rinit:  Initial prior mean on z for :code:`est_init`
        :param rvar_init:  Initial prior variance on z for :code:`est_init`
        
        :note:  The penalty matches the Tensorflow operator
           :code:`tf.nn.sigmoid_cross_entropy_with_logits`
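        
        A minimal usage sketch (labels and shapes are illustrative only)::
        
            y = np.array([0., 1., 1., 0.])     # binary class labels
            est = LogisticEst(y)
            zhat, zhatvar = est.est_init()     # MAP estimate under the initial prior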
        """
        # Save parameters
        shape = y.shape
        if rinit is None:
            self.rinit = np.zeros(shape)
        else:
            self.rinit = rinit
        
        # Initialize the base class
        ScalarNLEst.__init__(self,shape,var_axes=var_axes,max_it=max_it,\
            step_init=1,dtype=dtype,name=name,\
            step_max=1,step_min=1e-8,zinit=self.rinit)
            
        if np.isscalar(rvar_init):
            var_shape = common.get_var_shape(self.shape,self.var_axes)
            rvar_init = np.tile(rvar_init, var_shape)
        self.rvar_init = rvar_init                        
        
        # Indicate that the Hessian is available
        self.hess_avail = True
        
        # Save class label
        self.y = y

    def fnl(self,z):
        """
        Logistic function
        
        :returns:  :code:`f,fgrad,fhess` the logistic function, its gradient
           and hessian
        """
        p = 1/(1+np.exp(-z))
        q = 1/(1+np.exp(z))
        f = np.maximum(z,0) - self.y*z + np.log(1+np.exp(-np.abs(z)))
        fgrad = p-self.y 
        fhess = p*q        
        return f, fgrad, fhess