#!/usr/bin/env python

"""
Given a dataset and a vocabulary file, keep only the recordings of the desired symbols.
"""
# This has to work with preprocess_dataset.py

# Core Library modules
import csv
import logging
import os
import pickle
import sys
from typing import Any, Dict, List, Sequence, Set

# Third party modules
import pkg_resources
import yaml
from natsort import natsorted

logger = logging.getLogger(__name__)


def main(symbol_yml_file: str, raw_pickle_file: str, pickle_dest_path: str) -> None:
    """
    Parameters
    ----------
    symbol_yml_file : str
        Path to a YAML file which defines the symbols (classes) to keep.
    raw_pickle_file : str
        Path to a pickle file which contains raw recordings.
    pickle_dest_path : str
        Path where the filtered dictionary will be written as a pickle file.
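
    Examples
    --------
    The file names below are placeholders:

    >>> main("symbols.yml", "raw.pickle", "filtered.pickle")  # doctest: +SKIP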
    """
    metadata = get_metadata()
    symbol_ids = get_symbol_ids(symbol_yml_file, metadata)
    id_map = transform_sids(symbol_ids)
    raw = load_raw(raw_pickle_file)
    filter_and_save(raw, id_map, pickle_dest_path)


def get_symbol_ids(
    symbol_yml_file: str, metadata: Dict[Any, Any]
) -> List[Dict[str, Any]]:
    r"""
    Get a list of output classes, each of which describes which
    write-math.com ids get mapped to it.

    Parameters
    ----------
    symbol_yml_file : str
        Path to a YAML file.
    metadata : Dict[Any, Any]
        Metainformation of symbols, like the id on write-math.com.
        Has keys 'symbols', 'tags', 'tags2symbols'.

    Returns
    -------
    symbol_ids : List[Dict[str, Any]]
        Each dictionary represents one output class and has
        to have the keys 'id' (which is an id on write-math.com) and
        'mappings' (which is a list of ids on write-math.com). The mappings
        list should at least contain the id itself, but can contain more.

    Examples
    --------
    >>> from hwrt.utils import get_symbols_filepath
    >>> metadata = {"symbols": [{"formula_in_latex": r"\alpha", "id": 42},
    ...                         {"formula_in_latex": r"\beta", "id": 1337}]}
    >>> out = get_symbol_ids(get_symbols_filepath(testing=True), metadata=metadata)
    >>> len(out)
    2
    >>> out[0]
    {'id': 42, 'formula_in_latex': '\\alpha', 'mappings': [42]}
    >>> out[1]
    {'id': 1337, 'formula_in_latex': '\\beta', 'mappings': [1337]}

    The YAML file has to be of the structure

    ```
    - {latex: 'A'}
    - {latex: 'B'}
    - {latex: 'O',
       mappings: ['0', 'o']}
    - {latex: 'C'}
    - {latex: '::REJECT::',
       mappings: ['::ALL_FREE::']}
    - {latex: '::ARROW::',
       mappings: ['::TAG/arrow::'],
       exclude: ['\rightarrow']}
    ```
    """
    with open(symbol_yml_file) as stream:
        symbol_cfg = yaml.safe_load(stream)
    symbol_ids = []
    symbol_ids_set: Set[Any] = set()

    for symbol in symbol_cfg:
        if "latex" not in symbol:
            logger.error(
                "Key 'latex' not found for a symbol in %s (%s)", symbol_yml_file, symbol
            )
            sys.exit(-1)
        results = [
            el
            for el in metadata["symbols"]
            if el["formula_in_latex"] == symbol["latex"]
        ]
        if len(results) != 1:
            logger.warning(
                "Found %i results for %s: %s", len(results), symbol["latex"], results
            )
            if len(results) > 1:
                results = natsorted(results, key=lambda n: n["id"])
            else:
                sys.exit(-1)
        mapping_ids = [results[0]["id"]]
        if "mappings" in symbol:
            for msymbol in symbol["mappings"]:
                # Mappings are given as plain LaTeX strings in the YAML file,
                # but a {latex: ...} dict is accepted as well.
                latex = msymbol["latex"] if isinstance(msymbol, dict) else msymbol
                filtered = [
                    el
                    for el in metadata["symbols"]
                    if el["formula_in_latex"] == latex
                ]
                if len(filtered) != 1:
                    logger.error(
                        "Found %i results for %s: %s", len(filtered), msymbol, filtered
                    )
                    if len(filtered) > 1:
                        filtered = natsorted(filtered, key=lambda n: n["id"])
                    else:
                        sys.exit(-1)
                mapping_ids.append(filtered[0]["id"])
        symbol_ids.append(
            {
                "id": int(results[0]["id"]),
                "formula_in_latex": results[0]["formula_in_latex"],
                "mappings": mapping_ids,
            }
        )
        for id_tmp in mapping_ids:
            if id_tmp in symbol_ids_set:
                # Find the class which already claimed this write-math id.
                clash = next(s for s in symbol_ids if id_tmp in s["mappings"])
                logger.error("Symbol id %s is already used: %s", id_tmp, clash)
                sys.exit(-1)
            symbol_ids_set.add(id_tmp)

    # TODO: assert no double mappings
    # TODO: Support for
    # - ::ALL_FREE:: - meaning the rest of all ids which are not assigned to
    #                  any other class get assigned to this class
    # - ::TAG/arrow:: - meaning all ids of the tag arrow get assigned here
    # - exclude
    logger.info(
        "%i base classes and %i write-math ids.", len(symbol_ids), len(symbol_ids_set)
    )
    return symbol_ids


def transform_sids(symbol_ids: List[Dict[str, Any]]) -> Dict[int, int]:
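    """
    Flatten the mapping lists of get_symbol_ids into a single lookup table.

    Parameters
    ----------
    symbol_ids : List[Dict[str, Any]]
        As returned by get_symbol_ids; every entry has the keys 'id' and
        'mappings'.

    Returns
    -------
    new_sids : Dict[int, int]
        Maps each write-math.com id to the id of the class it belongs to.

    Examples
    --------
    >>> transform_sids([{"id": 42, "mappings": [42, 43]},
    ...                 {"id": 1337, "mappings": [1337]}])
    {42: 42, 43: 42, 1337: 1337}
    """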
    new_sids = {}
    for to_sid in symbol_ids:
        for from_sid in to_sid["mappings"]:
            new_sids[int(from_sid)] = int(to_sid["id"])
    return new_sids


def get_metadata() -> Dict[str, Any]:
    """
    Get metadata of symbols, like their tags, id on write-math.com, LaTeX
    command and unicode code point.
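
    Examples
    --------
    >>> metadata = get_metadata()
    >>> sorted(metadata.keys())
    ['symbols', 'tags', 'tags2symbols']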
    """
    misc_path = pkg_resources.resource_filename("hwrt", "misc/")
    wm_symbols = os.path.join(misc_path, "wm_symbols.csv")
    wm_tags = os.path.join(misc_path, "wm_tags.csv")
    wm_tags2symbols = os.path.join(misc_path, "wm_tags2symbols.csv")
    return {
        "symbols": read_csv(wm_symbols),
        "tags": read_csv(wm_tags),
        "tags2symbols": read_csv(wm_tags2symbols),
    }


def read_csv(filepath: str) -> Sequence[Dict[Any, Any]]:
    """
    Read a CSV file into a list of dictionaries. The first line of the CSV
    file determines the keys of the dictionaries.

    Parameters
    ----------
    filepath : str
        Path to a CSV file with a header row.

    Returns
    -------
    symbols : Sequence[Dict[Any, Any]]
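
    Examples
    --------
    A sketch; "symbols.csv" is a placeholder for any CSV file whose header
    row is "id,formula_in_latex":

    >>> rows = read_csv("symbols.csv")  # doctest: +SKIP
    >>> sorted(rows[0].keys())  # doctest: +SKIP
    ['formula_in_latex', 'id']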
    """
    symbols = []
    with open(filepath) as csvfile:
        reader = csv.DictReader(csvfile, delimiter=",", quotechar='"')
        for row in reader:
            symbols.append(row)
    return symbols


def load_raw(raw_pickle_file: str) -> Dict[Any, Any]:
    """
    Load a pickle file of raw recordings.

    Parameters
    ----------
    raw_pickle_file : str
        Path to a pickle file which contains raw recordings.

    Returns
    -------
    raw : Dict[Any, Any]
        The loaded pickle file.
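
    Examples
    --------
    "raw.pickle" is a placeholder for the raw-recordings pickle file:

    >>> raw = load_raw("raw.pickle")  # doctest: +SKIP
    >>> type(raw["handwriting_datasets"])  # doctest: +SKIP
    <class 'list'>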
    """
    with open(raw_pickle_file, "rb") as f:
        raw = pickle.load(f)
    logger.info("Loaded %i recordings.", len(raw["handwriting_datasets"]))
    return raw


def filter_and_save(
    raw: Dict[Any, Any], symbol_ids: Dict[int, int], destination_path: str
) -> None:
    """
    Parameters
    ----------
    raw : Dict[Any, Any]
        with key 'handwriting_datasets'
    symbol_ids : Dict[int, int]
        Maps a write-math.com formula id to the id of its target class
    destination_path : str
        Path where the filtered dict 'raw' will be saved
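
    Examples
    --------
    A minimal sketch with hand-built data; in practice the 'handwriting'
    entry is a recording object with a formula_id attribute, for which
    SimpleNamespace stands in here, and "out.pickle" is a placeholder path:

    >>> from types import SimpleNamespace
    >>> raw = {"handwriting_datasets": [
    ...     {"formula_id": 43, "handwriting": SimpleNamespace(formula_id=43)}]}
    >>> filter_and_save(raw, {42: 42, 43: 42}, "out.pickle")  # doctest: +SKIP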
    """
    logger.info("Start filtering...")
    new_hw_ds = []
    for el in raw["handwriting_datasets"]:
        if el["formula_id"] in symbol_ids:
            el["formula_id"] = symbol_ids[el["formula_id"]]
            el["handwriting"].formula_id = symbol_ids[el["formula_id"]]
            new_hw_ds.append(el)
    raw["handwriting_datasets"] = new_hw_ds

    # Serialize the filtered dataset with pickle protocol 2.
    logger.info("Start dumping %i recordings...", len(new_hw_ds))
    with open(destination_path, "wb") as f:
        pickle.dump(raw, f, 2)