import json
from functools import partial
from inspect import signature
from operator import itemgetter
from pathlib import Path
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from fastai.core import to_np
from fastai.learner import Learner, ModelData
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm

from quicknlp.data.model_helpers import BatchBeamTokens

# Recurrent state container: either a single tensor, or a list whose entries are each a
# tensor or a pair of tensors (presumably an LSTM-style (hidden, cell) pair - TODO confirm).
States = Union[List[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]], torch.Tensor]

# Hyperparameter given either as a single int or as a list of ints.
HParam = Union[List[int], int]


class RandomUniform:
    """Callable source of uniform random numbers backed by a pre-generated buffer.

    ``numbers`` samples are drawn up front with ``np.random.rand``. Every call hands out
    the next unused sample; once the buffer has been fully consumed a fresh buffer of the
    same size is drawn and consumption restarts from its first entry.
    """

    def __init__(self, numbers=1000000):
        self.numbers = numbers
        self.array = np.random.rand(numbers)
        self.count = 0

    def __call__(self, *args, **kwargs):
        """Return the next buffered sample; any arguments are accepted and ignored."""
        position = self.count
        if position >= self.array.size:
            # Buffer exhausted: restart from the beginning of a newly drawn buffer
            position = self.count = 0
            self.array = np.random.rand(self.numbers)
        sample = self.array[position]
        self.count = position + 1
        return sample


def concat_layer_bidir_state(states: States, bidir):
    """Merge the first axis of a single layer's state into its feature axis.

    When ``bidir`` is falsy the state is returned untouched. Otherwise every tensor has
    its first two axes swapped and is viewed as ``(1, -1, 2 * feature_size)``, which for a
    tensor whose first axis has size 2 (assumed to be the two directions - TODO confirm)
    puts both slices side by side in the last axis.

    Args:
        states (States): A tensor (gru case) or a list/tuple whose first two entries are
            tensors (lstm case)
        bidir: Whether the state comes from a bidirectional layer

    Returns:
        The untouched ``states`` if ``bidir`` is falsy, a merged tensor for a tensor input,
        or a 2-tuple of merged tensors for a list/tuple input
    """
    if not bidir:
        return states

    def merge(tensor, feature_size):
        # contiguous() is required because view() cannot operate on the transposed layout
        return tensor.transpose(1, 0).contiguous().view(1, -1, 2 * feature_size)

    if isinstance(states, (list, tuple)):  # lstm case
        first, second = states[0], states[1]
        return merge(first, first.size(-1)), merge(second, second.size(-1))
    return merge(states, states[0].size(-1))  # gru case


def concat_bidir_state(states: States, bidir: bool, cell_type: str, nlayers: int) -> States:
    """Apply ``concat_layer_bidir_state`` to a state or to every entry of a list of states.

    Args:
        states (States): Either a list (handled entry by entry) or a single state
        bidir (bool): Forwarded to ``concat_layer_bidir_state``
        cell_type (str): Not used by this function
        nlayers (int): Not used by this function

    Returns:
        A new list with one converted entry per input entry when ``states`` is a list,
        otherwise the converted single state
    """
    if not isinstance(states, list):
        return concat_layer_bidir_state(states, bidir=bidir)
    return [concat_layer_bidir_state(layer_state, bidir=bidir) for layer_state in states]


def print_dialogue_features(modeldata: ModelData, num_batches: int, num_sentences: int):
    """Print the decoded input, response and target of training dialogue samples.

    Batches are pulled from ``modeldata.trn_dl`` one at a time, so iteration stops as soon
    as the requested number of batches has been printed instead of first loading the whole
    training set into memory.

    Args:
        modeldata (ModelData): Provides the ``trn_dl`` iterator and the ``itos`` decoder
        num_batches (int): Maximum number of batches to print; a value <= 0 prints all batches
        num_sentences (int): Maximum number of samples to print per batch; a value <= 0
            prints all samples
    """
    for batch_num, (*x, y) in enumerate(modeldata.trn_dl):
        context, response, target = to_np(x[0]), to_np(x[1]), to_np(y)
        context = np.transpose(context, [1, 2, 0])  # transpose number of utterances to beams [sl, bs, nb]
        inputs_str = ["\n".join(conv) for conv in modeldata.itos(context, "text")]
        targets_str = modeldata.itos(target, "text")
        response_str = modeldata.itos(response, "text")
        for index, (inp, resp, targ) in enumerate(zip(inputs_str, response_str, targets_str)):
            print(
                f'BATCH: {batch_num} SAMPLE : {index}\n'
                f'INPUT:\n{"".join(inp)}, {len(inp.split())}\n'
                f'RESPONSE:\n{"".join(resp)}, {len(resp[0].split())}\n'
                f'TARGET:\n{"".join(targ)}, {len(targ[0].split())}\n\n')
            # index is zero-based, so index + 1 samples have been printed at this point
            if 0 < num_sentences <= index + 1:
                break
        # batch_num is zero-based, so batch_num + 1 batches have been printed at this point
        if 0 < num_batches <= batch_num + 1:
            break


def print_features(modeldata: ModelData, num_batches=1, num_sentences=-1):
    """Print the decoded input, target and response of training samples.

    Batches are pulled from ``modeldata.trn_dl`` one at a time, so iteration stops as soon
    as the requested number of batches has been printed instead of first loading the whole
    training set into memory.

    Args:
        modeldata (ModelData): Provides the ``trn_dl`` iterator (with ``source_names`` and
            ``target_names``) and the ``itos`` decoder
        num_batches (int): Maximum number of batches to print; a value <= 0 prints all batches
        num_sentences (int): Maximum number of samples to print per batch; a value <= 0
            prints all samples
    """
    loader = modeldata.trn_dl
    for batch_num, (*x, y) in enumerate(loader):
        inputs_str: BatchBeamTokens = modeldata.itos(to_np(x[0]), loader.source_names[0])
        response_str: BatchBeamTokens = modeldata.itos(to_np(x[1]), loader.source_names[1])
        targets_str: BatchBeamTokens = modeldata.itos(to_np(y), loader.target_names[0])
        for index, (inp, targ, resp) in enumerate(zip(inputs_str, targets_str, response_str)):
            print(
                f'batch: {batch_num} sample : {index}\n'
                f'input: {" ".join(inp)}\n'
                f'target: {" ".join(targ)}\n'
                f'response: {" ".join(resp)}\n\n')
            # index is zero-based, so index + 1 samples have been printed at this point
            if 0 < num_sentences <= index + 1:
                break
        # batch_num is zero-based, so batch_num + 1 batches have been printed at this point
        if 0 < num_batches <= batch_num + 1:
            break


def print_batch(learner: Learner, modeldata: ModelData, input_field, output_field, num_batches=1, num_sentences=-1,
                is_test=False, num_beams=1, weights=None, smoothing_function=None):
    """Print inputs, targets and predictions with their sentence BLEU, then the mean BLEU.

    Args:
        learner (Learner): Supplies ``predict_with_targs_and_inputs``
        modeldata (ModelData): Supplies the ``itos`` decoder
        input_field: Field name used to decode the inputs
        output_field: Field name used to decode the targets and the predictions
        num_batches (int): Maximum number of batches to print; a value <= 0 prints all batches
        num_sentences (int): Maximum number of samples to print per batch; a value <= 0
            prints all samples
        is_test (bool): Forwarded to ``predict_with_targs_and_inputs``
        num_beams (int): Forwarded to ``predict_with_targs_and_inputs``
        weights: n-gram weights passed to ``sentence_bleu``; three equal weights of 1/3 if None
        smoothing_function: Smoothing passed to ``sentence_bleu``;
            ``SmoothingFunction().method1`` if None
    """
    predictions, targets, inputs = learner.predict_with_targs_and_inputs(is_test=is_test, num_beams=num_beams)
    if weights is None:
        weights = (1 / 3., 1 / 3., 1 / 3.)
    if smoothing_function is None:
        smoothing_function = SmoothingFunction().method1
    bleu_scores = []
    for batch_num, (input_batch, target, prediction) in enumerate(zip(inputs, targets, predictions)):
        inputs_str: BatchBeamTokens = modeldata.itos(input_batch, input_field)
        predictions_str: BatchBeamTokens = modeldata.itos(prediction, output_field)
        targets_str: BatchBeamTokens = modeldata.itos(target, output_field)
        for index, (inp, targ, pred) in enumerate(zip(inputs_str, targets_str, predictions_str)):
            bleu_score = sentence_bleu([targ], pred, smoothing_function=smoothing_function, weights=weights)
            print(
                f'batch: {batch_num} sample : {index}\n'
                f'input: {" ".join(inp)}\n'
                f'target: {" ".join(targ)}\n'
                f'prediction: {" ".join(pred)}\n'
                f'bleu: {bleu_score}\n\n')
            bleu_scores.append(bleu_score)
            # index is zero-based, so index + 1 samples have been printed at this point
            if 0 < num_sentences <= index + 1:
                break
        # batch_num is zero-based, so batch_num + 1 batches have been printed at this point
        if 0 < num_batches <= batch_num + 1:
            break
    print(f'mean bleu score: {np.mean(bleu_scores)}')


def print_dialogue_batch(learner: Learner, modeldata: ModelData, input_field, output_field, num_batches=1,
                         num_sentences=-1, is_test=False,
                         num_beams=1, smoothing_function=None, weights=None):
    """Print dialogue contexts, targets and predictions with their BLEU, then mean and std.

    Args:
        learner (Learner): Supplies ``predict_with_targs_and_inputs``
        modeldata (ModelData): Supplies the ``itos`` decoder
        input_field: Field name used to decode the inputs
        output_field: Field name used to decode the targets and the predictions
        num_batches (int): Maximum number of batches to print; a value <= 0 prints all batches
        num_sentences (int): Maximum number of samples to print per batch; a value <= 0
            prints all samples
        is_test (bool): Forwarded to ``predict_with_targs_and_inputs``
        num_beams (int): Forwarded to ``predict_with_targs_and_inputs``
        smoothing_function: Smoothing passed to ``sentence_bleu``;
            ``SmoothingFunction().method1`` if None
        weights: n-gram weights passed to ``sentence_bleu``; three equal weights of 1/3 if None
    """
    if weights is None:
        weights = (1 / 3., 1 / 3., 1 / 3.)
    if smoothing_function is None:
        smoothing_function = SmoothingFunction().method1
    predictions, targets, inputs = learner.predict_with_targs_and_inputs(is_test=is_test, num_beams=num_beams)
    bleu_scores = []
    for batch_num, (input_batch, target, prediction) in enumerate(zip(inputs, targets, predictions)):
        input_batch = np.transpose(input_batch, [1, 2, 0])  # transpose number of utterances to beams [sl, bs, nb]
        conversations: BatchBeamTokens = modeldata.itos(input_batch, input_field)
        inputs_str: List[str] = ["\n".join(conv) for conv in conversations]
        predictions_str: BatchBeamTokens = modeldata.itos(prediction, output_field)
        targets_str: BatchBeamTokens = modeldata.itos(target, output_field)
        for index, (inp, targ, pred) in enumerate(zip(inputs_str, targets_str, predictions_str)):
            reference = targ[0].split()
            # The first token of the prediction is dropped before comparing
            # (presumably a start-of-sequence marker - TODO confirm)
            hypothesis = pred[0].split()[1:]
            if reference == hypothesis:
                # An exact match is scored as 1 directly, without calling sentence_bleu
                bleu_score = 1
            else:
                bleu_score = sentence_bleu([reference], hypothesis,
                                           smoothing_function=smoothing_function,
                                           weights=weights)
            print(
                f'BATCH: {batch_num} SAMPLE : {index}\n'
                f'INPUT:\n{"".join(inp)}\n'
                f'TARGET:\n{"".join(targ)}\n'
                f'PREDICTON:\n{"".join(pred)}\n'
                f'blue: {bleu_score}\n\n')
            bleu_scores.append(bleu_score)
            # index is zero-based, so index + 1 samples have been printed at this point
            if 0 < num_sentences <= index + 1:
                break
        # batch_num is zero-based, so batch_num + 1 batches have been printed at this point
        if 0 < num_batches <= batch_num + 1:
            break
    print(f'bleu score: mean: {np.mean(bleu_scores)}, std: {np.std(bleu_scores)}')


def get_trainable_parameters(model: nn.Module, grad=False) -> List[str]:
    """List the names of the model parameters whose ``requires_grad`` is ``True``.

    Args:
        model (nn.Module): Module whose ``named_parameters()`` are inspected
        grad: When truthy, additionally require that the parameter's ``grad`` is not None

    Returns:
        List[str]: Parameter names, in ``named_parameters()`` order
    """
    selected = []
    for name, param in model.named_parameters():
        if grad and param.grad is None:
            continue
        if param.requires_grad is True:
            selected.append(name)
    return selected


def get_list(value: Union[List[Any], Any], multiplier: int = 1) -> List[Any]:
    """Return ``value`` as a list of length ``multiplier``.

    Args:
        value: A list, which is returned as is after its length has been checked, or any
            other object, which is repeated ``multiplier`` times
        multiplier (int): The expected length of the result

    Raises:
        AssertionError: If ``value`` is a list whose length is not ``multiplier``
    """
    if not isinstance(value, list):
        return [value] * multiplier
    assert len(value) == multiplier, f"{value} is not the correct size {multiplier}"
    return value


# Leaf value types named by the assert_dims annotations: a numpy array, a torch tensor or a plain number.
Array = Union[np.ndarray, torch.Tensor, int, float]


def assert_dims(value: Sequence[Array], dims: List[Optional[int]]) -> Sequence[Array]:
    """Given a nested list, with possibly torch or numpy tensors inside, assert it agrees with the
        dims provided

    Each list level consumes the first entry of ``dims`` as its expected length and its
    rows are checked against the remaining entries. An object with a ``shape`` attribute
    must have exactly as many axes as there are remaining entries. Objects that are
    neither lists nor have a ``shape`` are accepted without any check.

    Args:
        value (Sequence[Array]): A list of lists with potentially arrays inside
        dims (List[Optional[int]]): A list with the expected dims. None is used if the dim size can be
            anything. For an array axis a tuple of ints can be given, in which case the axis size has
            to be one of the tuple's members

    Returns:
        The ``value`` that was passed in, unchanged

    Raises:
        AssertionError if the value does not comply with the dims provided
    """
    if isinstance(value, list):
        if dims[0] is not None:
            assert len(value) == dims[0], f'{value} does not match {dims}'
        # None only lifts the constraint on the length of this level; the rows themselves
        # still have to agree with the remaining dims
        for row in value:
            assert_dims(row, dims[1:])
    # support for collections with a shape variable, e.g. torch.Tensor, np.ndarray, Variable
    elif hasattr(value, "shape"):
        shape = value.shape
        assert len(shape) == len(dims), f'{shape} does not match {dims}'
        for actual_dim, expected_dim in zip(shape, dims):
            if expected_dim is None:
                continue
            if isinstance(expected_dim, tuple):
                assert actual_dim in expected_dim, f'{shape} does not match {dims}'
            else:
                assert actual_dim == expected_dim, f'{shape} does not match {dims}'
    return value


def get_kwarg(kwargs, name, default_value=None, remove=True):
    """Look up ``name`` in ``kwargs``, falling back to ``default_value`` when it is absent.

    With ``remove`` set (the default) the entry is also taken out of ``kwargs``; otherwise
    ``kwargs`` is left unmodified.
    """
    lookup = kwargs.pop if remove else kwargs.get
    return lookup(name, default_value)


def call_with_signature(callable_fn: Callable, *args, **kwargs):
    """Call ``callable_fn`` with ``args`` and only those ``kwargs`` named in its signature.

    Keyword arguments whose name is not a parameter name of ``callable_fn`` are dropped
    silently. Positional arguments are always forwarded as they are.
    """
    parameter_names = signature(callable_fn).parameters
    accepted = {name: kwargs[name] for name in parameter_names if name in kwargs}
    return callable_fn(*args, **accepted)


def get_pairs_from_dialogues(path_dir, utterance_key, sort_key, role_key, text_key, response_role):
    """Yield context/response pairs from every ``*.json`` file directly inside ``path_dir``.

    Each file is loaded as an iterable of dialogues. The utterances of a dialogue are
    sorted and walked in order; every utterance is rendered as ``__<role>__ <text>`` and
    appended, preceded by a single space, to the running context. Whenever an utterance of
    ``response_role`` is reached and the context is not empty, a pair is produced.

    Args:
        path_dir: Directory object offering ``glob`` (e.g. a ``pathlib.Path``)
        utterance_key: Key of a dialogue under which its utterances are stored
        sort_key: Either the name of the utterance field to sort by or a callable used as
            the ``key`` of ``sorted``
        role_key: Key of an utterance holding the speaker role (a string)
        text_key: Key of an utterance holding its text
        response_role: Role whose utterances are emitted as responses

    Yields:
        dict: ``context`` holds all previous utterances of the dialogue and ``response``
        the current utterance, both including the role markers

    Raises:
        ValueError: If ``sort_key`` is neither a string nor a callable. It is raised when
            the generator is first advanced, before any file is read
    """
    # The sort key does not depend on the file or the dialogue, so resolve and validate it once
    if isinstance(sort_key, str):
        key = itemgetter(sort_key)
    elif callable(sort_key):
        key = sort_key
    else:
        raise ValueError("Invalid sort_key provided")
    for file in path_dir.glob("*.json"):
        with file.open('r', encoding='utf-8') as fh:
            dialogues = json.load(fh)
        for dialogue in tqdm(dialogues, desc=f'processed file {file}'):
            conversation = sorted(dialogue[utterance_key], key=key)
            text = ""
            for utterance in conversation:
                conv_role = "__" + utterance[role_key] + "__"
                text_with_role = conv_role + " " + utterance[text_key]
                # The opening utterance has no context yet, so it can never be a response
                if text != "" and utterance[role_key] == response_role:
                    yield dict(context=text, response=text_with_role)
                text += " " + text_with_role


def save_pairs_to_tsv(pairs, filename):
    """Write context/response pairs to a UTF-8 ``.tsv`` file, one pair per line.

    Missing parent directories are created. Each line consists of the pair's ``context``
    and ``response`` values separated by a tab.

    Args:
        pairs: Iterable of mappings with the keys ``context`` and ``response``
        filename: Path of the output file; its name has to end with ``.tsv``

    Raises:
        AssertionError: If the file name does not end with ``.tsv``
    """
    target = Path(filename)
    assert target.name.endswith(".tsv")
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('w', encoding='utf-8') as fh:
        fh.writelines(f"{pair['context']}\t{pair['response']}\n" for pair in pairs)


def convert_dialogues_to_pairs(path_dir, output_dir, utterance_key, sort_key, role_key, text_key, response_role,
                               train_path=None, validation_path=None, test_path=None):
    """Convert the json dialogues of the given sub-folders to ``dialogues.tsv`` pair files.

    For every sub-folder that is not None the dialogues found in ``path_dir / folder`` are
    turned into context/response pairs by ``get_pairs_from_dialogues`` and written by
    ``save_pairs_to_tsv`` to ``output_dir / folder / "dialogues.tsv"``.

    Args:
        path_dir: Root directory of the input data (string or path object)
        output_dir: Root directory of the output data (string or path object)
        utterance_key: Forwarded to ``get_pairs_from_dialogues``
        sort_key: Forwarded to ``get_pairs_from_dialogues``
        role_key: Forwarded to ``get_pairs_from_dialogues``
        text_key: Forwarded to ``get_pairs_from_dialogues``
        response_role: Forwarded to ``get_pairs_from_dialogues``
        train_path: Name of the training sub-folder, skipped if None
        validation_path: Name of the validation sub-folder, skipped if None
        test_path: Name of the test sub-folder, skipped if None
    """
    path_dir = Path(path_dir)
    # Convert here as well: with a plain string the `/` joins below would raise a TypeError
    output_dir = Path(output_dir)
    iter_func = partial(get_pairs_from_dialogues, utterance_key=utterance_key, sort_key=sort_key,
                        role_key=role_key, text_key=text_key, response_role=response_role)

    for folder in (train_path, validation_path, test_path):
        if folder is None:
            continue
        save_pairs_to_tsv(iter_func(path_dir / folder), output_dir / folder / "dialogues.tsv")