from collections import OrderedDict
from copy import copy
from typing import Tuple, Sequence, Dict, Iterable, Union
from warnings import warn

import torch

from torch.nn import Parameter, ModuleDict, ParameterDict

from torch_kalman.internals.batch import Batchable
from torch_kalman.covariance import CovarianceFromLogCholesky, PartialCovarianceFromLogCholesky
from torch_kalman.internals.utils import infer_forward_kwargs

from torch_kalman.process import Process
from lazy_object_proxy.utils import cached_property

from torch_kalman.process.utils.design_matrix import (
    DynamicMatrix,
    TransitionMatrix,
    MeasureMatrix,
    ProcessVarianceMultiplierMatrix,
    MeasureVarianceMultiplierMatrix
)
from torch_kalman.internals.repr import NiceRepr
from torch_kalman.process.utils.design_matrix.utils import adjustments_from_nn
from torch_kalman.utils.nn import NamedEmbedding
from torch_kalman.utils.nn.fourier_season import FourierSeasonNN


class Design(NiceRepr, Batchable):
    """
    A class for specifying the 'design' of a KalmanFilter -- i.e. what measures are modeled by what processes.
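
    Example (a sketch; assumes a hypothetical `LocalTrend` process whose `add_measure()`
    attaches a measure-name and returns the process):

        processes = [LocalTrend(id='trend').add_measure('y')]
        design = Design(processes=processes, measures=['y'])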
    """
    _repr_attrs = ('process_list', 'measures')

    def __init__(self,
                 processes: Sequence[Process],
                 measures: Sequence[str],
                 measure_var_predict: Sequence[Union[torch.nn.Module, tuple]] = (),
                 process_var_predict: Sequence[Union[torch.nn.Module, tuple]] = ()
                 ):
        """
        :param processes: The processes to include in the design.
        :param measures: The names of the measures (i.e. the observed series).
        :param measure_var_predict: Optional NN(s), or alias-tuple(s), that predict adjustments to the
          measure-variances. See documentation for KalmanFilter.
        :param process_var_predict: As above, but for the process-variances. See documentation for KalmanFilter.
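
        Example of an alias-style variance predictor (a sketch; the integer is forwarded
        positionally to `NamedEmbedding`, with `embedding_dim` defaulting to the number of measures):

            design = Design(
                processes=processes,
                measures=['y'],
                measure_var_predict=[('per_group', 10)]
            )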
        """
        self.measures = tuple(measures)

        self.processes = OrderedDict()
        for process in processes:
            if process.id in self.processes:
                raise ValueError(f"Duplicate process-ids: {process.id}.")
            self.processes[process.id] = process

        self._validate()

        # process-variance predictions:
        self._process_var_nn = self._standardize_var_nn(process_var_predict, var_type='process', top_level=True)

        # measure-variance predictions:
        self._measure_var_nn = self._standardize_var_nn(measure_var_predict, var_type='measure', top_level=True)

        # params:

        # initial:
        self._initial_mean = None
        self.init_mean_params = Parameter(.1 * torch.randn(len(self.state_elements)))
        self.init_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.unfixed_state_elements
        )

        # process:
        self.process_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.dynamic_state_elements
        )

        # measure:
        self.measure_covariance = CovarianceFromLogCholesky(rank=len(self.measures))
        self._measure_var_adjustments = MeasureVarianceMultiplierMatrix(self.measures)

    @cached_property
    def state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element) for state_element in process.state_elements)
        return out

    @cached_property
    def dynamic_state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element) for state_element in process.dynamic_state_elements)
        return out

    @cached_property
    def unfixed_state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element) for state_element in process.state_elements
                       if state_element not in process.fixed_state_elements)
        return out

    @cached_property
    def process_slices(self) -> Dict[str, slice]:
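        """
        The contiguous slice of the state-dimension occupied by each process, keyed by process-id. For example
        (hypothetical ids/sizes): {'trend': slice(0, 2), 'season': slice(2, 6)}.
        """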
        process_slices = OrderedDict()
        start_counter = 0
        for process_name, process in self.processes.items():
            end_counter = start_counter + len(process.state_elements)
            process_slices[process_name] = slice(start_counter, end_counter)
            start_counter = end_counter
        return process_slices

    def _validate(self):
        if not self.measures:
            raise ValueError("Empty `measures`")
        if len(self.measures) != len(set(self.measures)):
            raise ValueError("Duplicates in `measures`")
        if not self.processes:
            raise ValueError("Empty `processes`")

        used_measures = set()
        for process_name, process in self.processes.items():
            for measure in process.measures:
                if measure not in self.measures:
                    raise RuntimeError(f"{measure} not in `measures`")
                used_measures.add(measure)

        unused_measures = set(self.measures).difference(used_measures)
        if unused_measures:
            raise ValueError(f"The following `measures` are not in any of the `processes`:\n{unused_measures}")

    # For Batch -------:
    def for_batch(self, num_groups: int, num_timesteps: int, **kwargs) -> 'Design':
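        """
        Create a batch-specific version of this design, routing process-specific keyword-arguments to each
        process's `for_batch()` (see `_parse_kwargs` for the routing rules).

        :param num_groups: Number of groups (i.e. the batch-size).
        :param num_timesteps: Number of timesteps.
        :param kwargs: Keyword-arguments passed along to the processes and the variance-predicting NNs.
        :return: A copy of this Design that is usable for a batch of data.
        """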
        for_batch = copy(self)
        for_batch.processes = OrderedDict()
        for_batch.batch_info = (num_groups, num_timesteps)
        for_batch._initial_mean = torch.zeros(num_groups, len(self.state_elements))

        batch_dim_kwargs = {'num_groups': num_groups, 'num_timesteps': num_timesteps}

        unused_kwargs = set(kwargs.keys())

        # processes:
        for process_name, process in self.processes.items():
            proc_kwargs, used = self._parse_kwargs(
                batch_kwargs=process.batch_kwargs(),
                prefix=process.id,
                all_kwargs=kwargs,
                aliases=getattr(process, 'batch_kwargs_aliases', {})
            )
            for k in used:
                unused_kwargs.discard(k)

            # wrap calls w/process-name for easier tracebacks:
            try:
                for_batch.processes[process_name] = process.for_batch(**batch_dim_kwargs, **proc_kwargs)
                for_batch._initial_mean[:, self.process_slices[process_name]] = process.initial_state_means_for_batch(
                    parameters=self.init_mean_params[self.process_slices[process_name]],
                    num_groups=num_groups,
                    **proc_kwargs
                )
            except Exception as e:
                # re-raise with the process-name for easier debugging:
                raise type(e)(f"Failed to call `for_batch()` for process '{process_name}' (see chained traceback above).") from e

            if for_batch.processes[process_name] is None:
                raise RuntimeError(f"{process_name}'s `for_batch` call did not return anything.")

        # var adjustments:
        for_batch._measure_var_adjustments = self._measure_var_adjustments.for_batch(**batch_dim_kwargs)
        for var_type, nn_list in {'measure': self._measure_var_nn, 'process': self._process_var_nn}.items():
            for i, nn in enumerate(nn_list):
                nn_kwargs, used = self._parse_kwargs(
                    prefix=f'{var_type}_var_nn{i}',
                    batch_kwargs=nn._forward_kwargs,
                    all_kwargs={**kwargs, **batch_dim_kwargs},
                    aliases=getattr(nn, '_forward_kwargs_aliases', {})
                )

                # a cheat that makes the `seasonal` alias more convenient:
                if 'datetimes' in nn._forward_kwargs and 'datetimes' not in nn_kwargs and hasattr(nn, '_dt_helper'):
                    if 'start_datetimes' in kwargs:
                        nn_kwargs['datetimes'] = nn._dt_helper.make_grid(kwargs['start_datetimes'], num_timesteps)

                for k in used:
                    unused_kwargs.discard(k)

                adjustments = adjustments_from_nn(
                    nn=nn,
                    **batch_dim_kwargs,
                    nn_kwargs=nn_kwargs,
                    output_names=self.measures if var_type == 'measure' else self.dynamic_state_elements,
                    time_split_kwargs=getattr(nn, '_time_split_kwargs', ())
                )

                for el, adj in adjustments.items():
                    for_batch._adjust_variance(el, adjustment=adj, check_slow_grad=False)

        if unused_kwargs:
            warn("Unexpected keyword arguments: {}".format(unused_kwargs))

        return for_batch

    @property
    def initial_mean(self):
        if self.is_for_batch:
            return self._initial_mean
        else:
            raise RuntimeError(
                f"Tried to access `{type(self).__name__}.initial_mean`, but only possible for output of `for_batch()`."
            )

    # Parameters -------:
    def param_dict(self) -> ModuleDict:
        p = ModuleDict()
        for process_name, process in self.processes.items():
            p[f"process:{process_name}"] = process.param_dict()

        p['measure_cov'] = self.measure_covariance.param_dict()
        p['measure_var_nn'] = self._measure_var_nn

        p['init_state'] = ParameterDict([('mean', self.init_mean_params)])
        p['init_state'].update(self.init_covariance.param_dict().items())

        p['process_cov'] = self.process_covariance.param_dict()
        p['process_var_nn'] = self._process_var_nn

        return p

    # Transition Matrix -------:
    @cached_property
    def F(self) -> DynamicMatrix:
        merged = TransitionMatrix.merge([(nm, process.transition_mat) for nm, process in self.processes.items()])
        assert list(merged.from_elements) == list(self.state_elements) == list(merged.to_elements)
        return merged.compile()

    # Measurement Matrix ------:
    @cached_property
    def H(self) -> DynamicMatrix:
        merged = MeasureMatrix.merge([(nm, process.measure_mat) for nm, process in self.processes.items()])
        assert list(merged.state_elements) == list(self.state_elements)
        # ensure the measure-dimension follows the order of `self.measures`:
        assert set(merged.measures) == set(self.measures)
        merged.measures[:] = self.measures
        return merged.compile()

    # Process-Covariance Matrix ------:
    def Q(self, t: int) -> torch.Tensor:
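        """
        The process-covariance matrix for timestep `t`. A diagonal matrix of per-state-element multipliers is
        applied to both sides of the base covariance: `diag_multi @ Q_base @ diag_multi`.
        """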
        # processes can apply multipliers to the variance of their state-elements:
        diag_multi = self._process_variance_multi(t=t)
        return diag_multi.matmul(self._base_Q).matmul(diag_multi)

    @cached_property
    def _process_variance_multi(self) -> DynamicMatrix:
        merged = ProcessVarianceMultiplierMatrix.merge(
            [(nm, process.variance_multi_mat) for nm, process in self.processes.items()]
        )
        assert list(merged.state_elements) == list(self.state_elements)
        return merged.compile()

    @cached_property
    def _base_Q(self):
        Q = self.process_covariance.create(leading_dims=())

        # the process-covariance is rescaled by the measure standard-deviations:
        Q_rescaled = self._scale_covariance(Q)

        # expand for batch-size:
        return Q_rescaled.expand(self.num_groups, -1, -1)

    # Measure-Covariance Matrix ------:
    def R(self, t: int):
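        """
        The measure-covariance matrix for timestep `t`, with per-measure variance-multipliers applied as in `Q`.
        """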
        diag_multi = self._measure_variance_multi(t=t)
        return diag_multi.matmul(self._base_R).matmul(diag_multi)

    @cached_property
    def _measure_variance_multi(self) -> DynamicMatrix:
        return self._measure_var_adjustments.compile()

    @cached_property
    def _base_R(self):
        return self.measure_covariance.create(leading_dims=(self.num_groups,))

    # Initial Cov ------:
    @cached_property
    def initial_covariance(self) -> torch.Tensor:
        init_cov = self.init_covariance.create(leading_dims=())
        # the initial covariance is rescaled by the measure standard-deviations:
        init_cov_rescaled = self._scale_covariance(init_cov)
        # expand for batch-size:
        return init_cov_rescaled.expand(self.num_groups, -1, -1)

    def _scale_covariance(self, cov: torch.Tensor) -> torch.Tensor:
        """
        Rescale covariances associated with processes (the process-covariance or the initial covariance) by the
        measure standard-deviations. Helpful in practice for training.
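
        Concretely, each process's rows/columns are multiplied by the geometric mean of the standard-deviations
        of the measures that process contributes to: `exp(mean(log(std(m)) for m in process.measures))`.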
        """
        measure_idx_by_measure = {measure: i for i, measure in enumerate(self.measures)}
        measure_log_stds = self.measure_covariance.create().diag().sqrt().log()
        diag_flat = torch.ones(len(self.state_elements))
        for process_name, process in self.processes.items():
            measure_idx = [measure_idx_by_measure[m] for m in process.measures]
            diag_flat[self.process_slices[process_name]] = measure_log_stds[measure_idx].mean().exp()
        diag_multi = torch.diagflat(diag_flat)
        cov_rescaled = diag_multi.matmul(cov).matmul(diag_multi)
        return cov_rescaled

    @property
    def process_list(self):
        return list(self.processes.values())

    # Private -----:
    def _parse_kwargs(self,
                      prefix: str,
                      all_kwargs: dict,
                      batch_kwargs: Iterable[str],
                      aliases: dict) -> Tuple[dict, set]:
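        """
        Select the kwargs in `all_kwargs` that belong to `prefix`, using sklearn-style double-underscore
        disambiguation: for a batch-kwarg `season` and prefix `weekly`, `weekly__season` takes precedence over a
        bare `season`. For example (hypothetical names):

            design.for_batch(num_groups=5, num_timesteps=10, weekly__season=season_X)

        :return: A tuple of (the parsed kwargs for `prefix`, the set of keys in `all_kwargs` that were used).
        """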
        too_generic = {'input', 'x'}

        # use sklearn-style disambiguation:
        used = set()
        out = {}
        for k in batch_kwargs:
            specific_key = "{}__{}".format(prefix, k)
            if specific_key in all_kwargs:
                out[k] = all_kwargs[specific_key]
                used.add(specific_key)
            elif k in all_kwargs:
                if k in too_generic:
                    raise ValueError(
                        f"The argument `{k}` is too generic, so it must be passed with a prefix specifying which "
                        f"component it should be routed to (e.g. `{specific_key}`)."
                    )
                out[k] = all_kwargs[k]
                used.add(k)
            else:
                alias = aliases.get(k) or aliases.get(specific_key)
                if alias in all_kwargs:
                    out[k] = all_kwargs[alias]
                    used.add(alias)
        return out, used

    def _standardize_var_nn(self,
                            var_nn: Union[torch.nn.Module, Sequence],
                            var_type: str,
                            top_level: bool = False) -> torch.nn.Module:
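        """
        Standardize the `*_var_predict` argument into a torch.nn.ModuleList. Each element may be a
        torch.nn.Module/callable, or an `('alias', args_or_kwargs)` tuple. For example (a sketch; see
        `NamedEmbedding` and `FourierSeasonNN` for their actual signatures):

            ('per_group', 10)       # -> NamedEmbedding(10, embedding_dim=num_outputs)
            ('seasonal', {'K': 2})  # -> FourierSeasonNN(K=2, num_outputs=num_outputs)
        """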

        if top_level:
            if isinstance(var_nn, torch.nn.ModuleList):
                return var_nn

            if callable(var_nn):
                # they passed a single NN instead of a list, wrap it:
                var_nn = [var_nn]
            elif len(var_nn) > 0 and isinstance(var_nn[0], str):
                # they passed a single alias instead of a list, wrap it:
                var_nn = [var_nn]

            return torch.nn.ModuleList([self._standardize_var_nn(sub_nn, var_type) for sub_nn in var_nn])
        else:
            if callable(var_nn):
                out_nn = var_nn
            elif isinstance(var_nn, (tuple, list)):
                alias, args_or_kwargs = var_nn
                num_outputs = len(self.measures if var_type == 'measure' else self.dynamic_state_elements)
                if alias == 'per_group' and isinstance(args_or_kwargs, int):
                    args_or_kwargs = (args_or_kwargs,)
                if isinstance(args_or_kwargs, dict):
                    args, kwargs = (), args_or_kwargs
                else:
                    args, kwargs = args_or_kwargs, {}

                if alias == 'per_group':
                    if 'embedding_dim' not in kwargs:
                        kwargs['embedding_dim'] = num_outputs
                    out_nn = NamedEmbedding(*args, **kwargs)
                    out_nn._forward_kwargs_aliases = {'input': 'group_names'}
                elif alias == 'seasonal':
                    out_nn = FourierSeasonNN(*args, **kwargs, num_outputs=num_outputs)
                    out_nn._time_split_kwargs = ['datetimes']
                else:
                    raise ValueError(f"Known aliases are 'per_group' and 'seasonal'; got '{alias}'")
            else:
                raise TypeError(
                    f"Expected `{var_type}_var_nn` to be a callable/torch.nn.Module, or a tuple with format "
                    f"`('alias', (arg1, arg2, ...))`. Instead got `{type(var_nn)}`."
                )
            if not hasattr(out_nn, '_forward_kwargs'):
                out_nn._forward_kwargs = infer_forward_kwargs(out_nn)
            if not hasattr(out_nn, '_forward_kwargs_aliases'):
                out_nn._forward_kwargs_aliases = {}
            return out_nn

    def _adjust_variance(self,
                         *args,
                         adjustment: 'DesignMatAdjustment',
                         check_slow_grad: bool = True,
                         ):
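        """
        Apply an adjustment to a variance. Called with a single measure-name, adjusts that measure's
        variance-multiplier; called with a (process-name, state-element) pair, delegates to that process's
        `_adjust_variance()`.
        """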
        # allow a (process, state_element) pair to be passed as a single tuple:
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            args = args[0]
        if len(args) == 1:
            assert args[0] in self.measures
            self._measure_var_adjustments.adjust(value=adjustment, check_slow_grad=check_slow_grad, measure=args[0])
        else:
            process, state_element = args
            self.processes[process]._adjust_variance(
                state_element=state_element, adjustment=adjustment, check_slow_grad=check_slow_grad
            )