import numpy as np
import math
from typing import Tuple, Optional
import pandas as pd

from .callback import Callback
from ..models.abs_model import AbsModel
from ...plotting.plot_settings import PlotSettings

import seaborn as sns
import matplotlib.pyplot as plt

__all__ = ['LRFinder']


class LRFinder(Callback):
    r'''
    Callback class for Smith learning-rate range test (https://arxiv.org/abs/1803.09820)

    The learning rate is increased geometrically after every batch, starting from the lower bound and
    reaching the upper bound after `nb` batches. The loss and LR of each batch are recorded in `history`.

    Arguments:
        nb: number of batches in a (sub-)epoch
        lr_bounds: tuple of initial and final LR
        model: :class:`~lumin.nn.models.Model` to alter, alternatively call :meth:`set_model`
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    # The default for lr_bounds is an immutable tuple: the object is stored by reference on the instance,
    # so a mutable list default would be shared between every LRFinder created without explicit bounds.
    def __init__(self, nb:int, lr_bounds:Tuple[float,float]=(1e-7, 10), model:Optional[AbsModel]=None,
                 plot_settings:PlotSettings=PlotSettings()):
        super().__init__(model=model, plot_settings=plot_settings)
        self.lr_bounds = lr_bounds
        # Per-batch multiplicative factor such that lr_bounds[0]*lr_mult**nb == lr_bounds[1]
        self.lr_mult = (self.lr_bounds[1]/self.lr_bounds[0])**(1/nb)

    def on_train_begin(self, **kargs) -> None:
        r'''
        Prepares variables and optimiser for new training: resets the best loss, the iteration counter, and the history,
        and sets the model LR to the lower bound.
        '''

        self.best,self.iter = math.inf,0
        self.model.set_lr(self.lr_bounds[0])
        self.history = {'loss': [], 'lr': []}

    def _calc_lr(self) -> float:
        r'''
        Computes the LR for the current iteration count

        Returns:
            lower LR bound multiplied by the per-batch factor raised to the power of the number of iterations so far
        '''

        return self.lr_bounds[0]*(self.lr_mult**self.iter)

    def plot(self, n_skip:int=0, n_max:Optional[int]=None, lim_y:Optional[Tuple[float,float]]=None) -> None:
        r'''
        Plot the loss as a function of the LR.
        The x-axis is logarithmic if the LR bounds span at least three orders of magnitude.

        Arguments:
            n_skip: Number of initial iterations to skip in plotting
            n_max: Maximum iteration number to plot
            lim_y: y-range for plotting
        '''

        # TODO: Decide on whether to keep this; could just pass to plot_lr_finders

        with sns.axes_style(self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette):
            plt.figure(figsize=(self.plot_settings.w_mid, self.plot_settings.h_mid))
            plt.plot(self.history['lr'][n_skip:n_max], self.history['loss'][n_skip:n_max], label='Training loss', color='g')
            if np.log10(self.lr_bounds[1])-np.log10(self.lr_bounds[0]) >= 3: plt.xscale('log')
            plt.ylim(lim_y)
            plt.grid(True, which="both")
            plt.legend(loc=self.plot_settings.leg_loc, fontsize=self.plot_settings.leg_sz)
            plt.xticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.yticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.ylabel("Loss", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.xlabel("Learning rate", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.show()

    def plot_lr(self) -> None:
        r'''
        Plot the LR as a function of iterations.
        '''

        with sns.axes_style(self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette):
            plt.figure(figsize=(self.plot_settings.h_small, self.plot_settings.h_small))
            plt.plot(range(len(self.history['lr'])), self.history['lr'])
            plt.xticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.yticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.ylabel("Learning rate", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.xlabel("Iterations", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.show()

    def get_df(self) -> pd.DataFrame:
        r'''
        Returns a DataFrame of LRs and losses

        Returns:
            DataFrame with columns 'LR' and 'Loss', one row per recorded batch
        '''

        return pd.DataFrame({'LR': self.history['lr'], 'Loss': self.history['loss']})

    def on_batch_end(self, loss:float, **kargs) -> None:
        r'''
        Records loss and increments LR.
        Training is flagged to stop if the loss is NaN, exceeds 100 times the best recorded loss, or the new LR exceeds the upper bound.

        Arguments:
            loss: training loss for most recent batch
        '''

        self.history['loss'].append(loss)
        # Only the first optimiser parameter group is read and updated here
        self.history['lr'].append(self.model.opt.param_groups[0]['lr'])
        self.iter += 1
        lr = self._calc_lr()
        self.model.opt.param_groups[0]['lr'] = lr
        # self.best stays at infinity until it is first updated below, so the divergence test cannot fire before then
        if math.isnan(loss) or loss > self.best*100 or lr > self.lr_bounds[1]: self.model.stop_train = True
        # The best loss is only tracked once more than 10 iterations have been run
        if loss < self.best and self.iter > 10: self.best = loss