"""
Utilities
---------
Module description
"""
from functools import reduce
from collections.abc import Iterable

import numpy as np
import torch
import pandas as pd

from brancher.config import device


def is_tensor(data):
    return torch.is_tensor(data)


def contains_tensors(data):
    if isinstance(data, dict):
        return all(is_tensor(d) for d in data.values())
    elif isinstance(data, Iterable):
        return all(is_tensor(d) for d in data)
    else:
        return False


def is_discrete(data):
    return type(data) in [list, set, tuple, dict, str]


def to_tuple(obj):
    if isinstance(obj, Iterable) and not isinstance(obj, torch.Tensor):
        return tuple(obj)
    else:
        return tuple([obj])


def to_tensor(arr):
    if isinstance(arr, torch.Tensor):
        return arr
    elif isinstance(arr, np.ndarray):
        return torch.Tensor(arr)
    else:
        raise ValueError("The input should be either a torch.Tensor or a np.array")


def map_iterable(func, itr, recursive=False):
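    """Apply func to every element of a dict, list, or tuple (nested if recursive=True).

    Example (illustrative):
        >>> map_iterable(lambda x: x + 1, {"a": 1, "b": 2})
        {'a': 2, 'b': 3}
    """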
    def f(x):
        if not recursive:
            return func(x)
        else:
            return map_iterable(func, x, recursive=True)
    if is_tensor(itr) or not isinstance(itr, Iterable):
        return func(itr)
    if isinstance(itr, dict):
        return {key: f(val) for key, val in itr.items()}
    out = [*map(f, itr)]
    if isinstance(itr, list):
        return out
    elif isinstance(itr, tuple):
        return tuple(out)
    else:
        raise TypeError("map_iterable only supports tensors, dicts, lists, and tuples")

def zip_dict(first_dict, second_dict):
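    """Zip two dictionaries on their shared keys, concatenating the values into tuples.

    Example (illustrative):
        >>> zip_dict({"a": 1, "b": 2}, {"a": 3})
        {'a': (1, 3)}
    """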
    keys = set(first_dict.keys()).intersection(set(second_dict.keys()))
    return {k: to_tuple(first_dict[k]) + to_tuple(second_dict[k]) for k in keys}


def zip_dict_list(dict_list):
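    """Fold a list of dictionaries into a single dictionary of tuples via zip_dict.

    The list is folded from the right, so the tuples come out in reverse list
    order. Example (illustrative):
        >>> zip_dict_list([{"a": 1}, {"a": 2}, {"a": 3}])
        {'a': (3, 2, 1)}
    """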
    if len(dict_list) == 0:
        return {}
    if len(dict_list) == 1:
        return dict_list[0]
    else:
        zipped_dict = zip_dict(dict_list[-1], dict_list[-2])
        new_dict_list = dict_list[:-2]
        new_dict_list.append(zipped_dict)
        return zip_dict_list(new_dict_list)


def split_dict(dic, condition):
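    """Split a dictionary in two according to a condition on its (key, value) pairs.

    Example (illustrative):
        >>> split_dict({"a": 1, "b": 2}, lambda key, val: val > 1)
        ({'b': 2}, {'a': 1})
    """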
    dict_1 = {}
    dict_2 = {}
    for key, val in dic.items():
        if condition(key, val):
            dict_1.update({key: val})
        else:
            dict_2.update({key: val})
    return dict_1, dict_2


def flatten_list(lst):
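    """Flatten a list of lists by one level, e.g. [[1, 2], [3]] -> [1, 2, 3]."""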
    flat_list = [item for sublist in lst for item in sublist]
    return flat_list


def flatten_set(st):
    flat_set = set([item for subset in st for item in subset])
    return flat_set


def join_dicts_list(dicts_list):
    if dicts_list:
        return reduce(lambda d1, d2: {**d1, **d2}, dicts_list)
    else:
        return {}


def join_sets_list(sets_list):
    if sets_list:
        return reduce(lambda d1, d2: d1.union(d2), sets_list)
    else:
        return set()


def sum_from_dim(tensor, dim_index):
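    """Sum a tensor over all of its trailing dimensions, starting at dim_index.

    Example (illustrative): a tensor of shape (n, m, d1, d2) with dim_index=2
    is reduced to shape (n, m).
    """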
    if is_tensor(tensor):
        data_dim = len(tensor.shape)
        for dim in reversed(range(dim_index, data_dim)):
            tensor = tensor.sum(dim=dim)
        return tensor
    else:
        # NumPy fallback; assumed to mirror the torch branch by reducing
        # all trailing axes from dim_index onward.
        return np.sum(tensor, axis=tuple(range(dim_index, tensor.ndim)))

def sum_data_dimensions(var):
    return sum_from_dim(var, dim_index=2)


def partial_broadcast(*args):
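    """Expand the first two (sample and datapoint) dimensions of every tensor to their maxima across args.

    Example (illustrative): tensors of shape (1, 5, 3) and (10, 1, 3) both
    come out with shape (10, 5, 3); trailing dimensions are left untouched.
    """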
    assert all(is_tensor(ar) for ar in args), "at least one input is not a torch.Tensor"
    shapes0, shapes1 = zip(*[(x.shape[0], x.shape[1]) for x in args])
    s0, s1 = np.max(shapes0), np.max(shapes1)
    return [x.expand((s0, s1) + x.shape[2:]) for x in args]


def tile_batch_dimensions(tensor, number_samples, number_datapoints):
    return tensor.expand((number_samples, number_datapoints) + tensor.shape[2:])


def broadcast_and_squeeze(*args):
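    """Broadcast all tensors to a common shape, first collapsing trivial trailing (event) dimensions to (1, 1)."""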
    assert all(is_tensor(ar) for ar in args), "at least one input is not a torch.Tensor"
    if all([np.prod(val.shape[2:]) == 1 for val in args]):
        args = [val.contiguous().view(size=val.shape[:2] + tuple([1, 1])) for val in args]
    uniformed_values = uniform_shapes(*args)
    broadcasted_values = torch.broadcast_tensors(*uniformed_values)
    return broadcasted_values


def broadcast_and_squeeze_mixed(tpl, dic):
    tpl_len = len(tpl)
    # zip(*...) on an empty dict would fail to unpack, so guard explicitly
    dict_keys, dict_values = zip(*dic.items()) if dic else ((), ())
    broadcasted_values = broadcast_and_squeeze(*(tpl + dict_values))
    if tpl_len > 0:
        return broadcasted_values[:tpl_len], {k: v for k, v in zip(dict_keys, broadcasted_values[tpl_len:])}
    else:
        return {k: v for k, v in zip(dict_keys, broadcasted_values)}


def get_items(itr, recursive=False):
    if is_tensor(itr) or not isinstance(itr, Iterable):
        return itr

    def f(x):
        if recursive:
            return get_items(x, recursive=True)
        else:
            return x

    if isinstance(itr, dict):
        return [f(val) for val in itr.values()]
    else:
        return [f(element) for element in itr]


def reshape_parent_value(value, number_samples, number_datapoints):
    newshape = tuple([number_samples * number_datapoints]) + value.shape[2:]
    return value.contiguous().view(size=newshape)


def broadcast_and_reshape_parent_value(value, number_samples, number_datapoints):
    return reshape_parent_value(tile_batch_dimensions(value, number_samples, number_datapoints),
                                number_samples, number_datapoints)


def get_number_samples_and_datapoints(parent_values):
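    """Infer the number of samples (dim 0) and datapoints (dim 1) from the parent values, taking the maximum over all tensors found."""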
    n_list = []
    m_list = []
    for value in parent_values.values():
        if is_tensor(value):
            n_list.append(value.shape[0])
            m_list.append(value.shape[1])
        elif contains_tensors(value):
            if isinstance(value, dict):
                itr = value.values()
            else:
                itr = value
            for tensor in itr:
                n_list.append(tensor.shape[0])
                m_list.append(tensor.shape[1])
    if not n_list and not m_list:
        return None, None
    else:
        return max(n_list), max(m_list)


def get_diagonal(tensor):
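    """Extract the main diagonal from a batch of square matrices.

    Input shape (dim1, dim2, k, k) -> output shape (dim1, dim2, k); an
    equivalent one-liner would be torch.diagonal(tensor, dim1=-2, dim2=-1).
    """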
    assert torch.is_tensor(tensor), "input is not a torch.Tensor"
    assert tensor.ndimension() == 4, "ndim should equal 4"
    dim1, dim2, dim_matrix, _ = tensor.shape
    diag_ind = list(range(dim_matrix))
    expanded_diag_ind = dim1*dim2*diag_ind  # list repetition: diagonal indices tiled dim1*dim2 times
    axis12_ind = [a for a in range(dim1*dim2) for _ in range(dim_matrix)]
    reshaped_tensor = tensor.contiguous().view(size=(dim1*dim2, dim_matrix, dim_matrix))
    ind = (np.array(axis12_ind), np.array(expanded_diag_ind), np.array(expanded_diag_ind))
    subdiagonal = reshaped_tensor[ind]
    return subdiagonal.view(size=(dim1, dim2, dim_matrix))


def coerce_to_dtype(data, is_observed=False):
    """Summary"""
    def reformat_tensor(result):
        if is_observed:
            result = torch.unsqueeze(result, dim=0)
            result_shape = result.shape
            if len(result_shape) == 2:
                result = result.contiguous().view(size=result_shape + tuple([1, 1]))
            elif len(result_shape) == 3:
                result = result.contiguous().view(size=result_shape + tuple([1]))
        else:
            result = torch.unsqueeze(torch.unsqueeze(result, dim=0), dim=1)
        return result

    dtype = type(data) ##TODO: do we need any additional shape checking?
    if dtype is torch.Tensor: # to tensor
        result = data.float()
    elif dtype is np.ndarray: # to tensor
        result = torch.tensor(data).float()
    elif dtype is pd.DataFrame:
        result = torch.tensor(data.values).float()
    elif dtype in [float, int] or dtype.__base__ in [np.floating, np.signedinteger]: # to tensor
        result = torch.tensor(data * np.ones(shape=(1, 1))).float()
    elif dtype in [list, set, tuple, dict, str]: # to discrete
        return data
    else:
        raise TypeError("Invalid input dtype {} - expected float, integer, np.ndarray, or torch var.".format(dtype))

    result = reformat_tensor(result)
    return result.to(device)


def tile_parameter(tensor, number_samples):
    assert is_tensor(tensor), "input is not a torch.Tensor"
    value_shape = tensor.shape
    if value_shape[0] == number_samples:
        return tensor
    elif value_shape[0] == 1:
        reps = tuple([number_samples] + [1] * len(value_shape[1:]))
        return tensor.repeat(repeats=reps)
    else:
        raise ValueError("The parameter cannot be broadcasted to the required number of samples")


def reformat_sampler_input(sample_input, number_samples):
    return {var: tile_parameter(coerce_to_dtype(value, is_observed=var.is_observed), number_samples=number_samples)
            for var, value in sample_input.items()}


def uniform_shapes(*args):
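    """Append a trailing singleton dimension to any tensor that is one dimension short of the longest input."""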
    assert all(is_tensor(ar) for ar in args), "at least one input is not a torch.Tensor"
    shapes = [ar.shape for ar in args]
    max_len = np.max([len(s) for s in shapes])
    return [torch.unsqueeze(ar, dim=len(ar.shape)) if (len(ar.shape) == max_len-1) else ar
            for ar in args]


def get_model_mapping(source_model, target_model):
    model_mapping = {}
    if isinstance(target_model, dict):
        target_variables = list(target_model.keys())
    else:
        target_variables = target_model._flatten()
    for p_var in target_variables:
        try:
            model_mapping.update({source_model.get_variable(p_var.name): p_var})
        except KeyError:
            pass
    return model_mapping


def reassign_samples(samples, model_mapping=None, source_model=None, target_model=None):
    out_sample = {}
    if model_mapping:
        pass
    elif source_model and target_model:
        model_mapping = get_model_mapping(source_model, target_model)
    else:
        raise ValueError("Either a model mapping or both source and target models have to be provided as input")
    for key, value in samples.items():
        try:
            out_sample.update({model_mapping[key]: value})
        except KeyError:
            pass
    return out_sample


def reject_samples(samples, model_statistics, truncation_rule):
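    """Rejection step: keep only the samples whose statistic satisfies truncation_rule.

    Returns the accepted samples, their count, and the empirical acceptance probability.
    """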
    decision_variable = model_statistics(samples)
    sample_indices = [index for index, value in enumerate(decision_variable) if truncation_rule(value)]
    num_accepted_samples = len(sample_indices)
    if num_accepted_samples == 0:
        return None, 0, 0.001 #TODO: Improve
    else:
        remaining_samples = {var: value[sample_indices, :] for var, value in samples.items()}
        acceptance_probability = num_accepted_samples/float(decision_variable.shape[0])
        return remaining_samples, num_accepted_samples, acceptance_probability


def concatenate_samples(samples_list):
    """Concatenate a list of sample dictionaries along the sample dimension (dim 0)."""
    if len(samples_list) == 1:
        return samples_list[0]
    else:
        paired_list = zip_dict_list(samples_list)
        samples = {var: torch.cat(tensor_tuple, dim=0)
                   for var, tensor_tuple in paired_list.items()}
        return samples


def tensor_range(tensor):
    return set(tensor.detach().cpu().numpy().flatten().tolist())


def batch_meshgrid(tensor1, tensor2):
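    """Build a batched meshgrid: (b, n) and (b, m) inputs are expanded to a common (b, n, m) grid.

    Example (illustrative): tensor1 of shape (2, 3) and tensor2 of shape (2, 4)
    both come out with shape (2, 3, 4).
    """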
    tensor1_shape = tensor1.shape
    tensor2_shape = tensor2.shape

    assert (len(tensor1_shape) == 2 and len(tensor2_shape) == 2), "batch_meshgrid only accepts 2D tensors (the first dimension is the batch dimension)"

    new_shape = [tensor1_shape[0], tensor1_shape[1], tensor2_shape[1]]
    tensor1 = tensor1.unsqueeze(dim=2).expand(*new_shape)
    tensor2 = tensor2.unsqueeze(dim=1).expand(*new_shape)
    return tensor1, tensor2


def get_tensor_data(tensor):
    return tensor.cpu().detach().numpy()


def delta(x, y):
    return (x == y).float()


def get_numerical_index_from_string(string):
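    """Strip leading non-numeric characters until the remainder parses as a float.

    Example (illustrative): "x12" -> "12"; strings with no numeric tail reduce to "".
    """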
    try:
        if not string:
            return string
        else:
            float(string)
            return string
    except ValueError:
        return get_numerical_index_from_string(string[1:])