"""Utility functions
"""
import inspect
import textwrap
from collections import namedtuple
import re
import os

SPLITS = ('train', 'val', 'test')


class VergeMLError(Exception):
    """System error.
    """

    def __init__(self, # pylint: disable=R0913
                 message, suggestion=None, help_topic=None, hint_type=None, hint_key=None):
        super().__init__(message)
        self.suggestion = suggestion
        self.message = message
        self.hint_type = hint_type
        self.hint_key = hint_key
        self.help_topic = help_topic

    def __str__(self):
        if self.suggestion:
            if len(self.message + self.suggestion) < 80:
                return self.message + " " + self.suggestion
            return self.message + "\n" + self.suggestion
        return self.message


def wrap_text(text):
    """Wrap text to be readable in the terminal.
    """
    # TODO check terminal width
    res = []
    for para in text.split("\n\n"):
        if para.splitlines()[0].strip().endswith(":"):
            res.append(para)
        else:
            res.append(textwrap.fill(para, drop_whitespace=True, fix_sentence_endings=True))
    return "\n\n".join(res)


_Intro = namedtuple('_Intro', ['args', 'defaults', 'types'])


def introspect(call):
    """Introspect a function call.
    """
    spec = inspect.getfullargspec(call)
    args = spec.args
    defaults = dict(zip(reversed(spec.args), reversed(spec.defaults or [])))
    types = spec.annotations
    return _Intro(args, defaults, types)


# taken from here: https://www.python-course.eu/levenshtein_distance.php
def _iterative_levenshtein(source, targ):
    """
        iterative_levenshtein(source, targ) -> ldist
        ldist is the Levenshtein distance between the strings
        source and targ.
        For all i and j, dist[i,j] will contain the Levenshtein
        distance between the first i characters of source and the
        first j characters of targ
    """
    rows = len(source)+1
    cols = len(targ)+1
    dist = [[0 for x in range(cols)] for x in range(rows)]
    # source prefixes can be transformed into empty strings
    # by deletions:
    for i in range(1, rows):
        dist[i][0] = i
    # target prefixes can be created from an empty source string
    # by inserting the characters
    for i in range(1, cols):
        dist[0][i] = i

    row, col = None, None

    for col in range(1, cols):
        for row in range(1, rows):
            if source[row-1] == targ[col-1]:
                cost = 0
            else:
                cost = 1
            dist[row][col] = min(dist[row-1][col] + 1,      # deletion
                                 dist[row][col-1] + 1,      # insertion
                                 dist[row-1][col-1] + cost) # substitution

    assert row and col
    return dist[row][col]


def did_you_mean(candidates, value, fmt="'{}'"):
    """In case of a misspelling, return possible candidates.
    """
    candidates = list(candidates)
    names = list(sorted(map(lambda n: (_iterative_levenshtein(value, n), n), candidates)))
    names = list(filter(lambda dn: dn[0] <= 2, names))
    return 'Did you mean ' + fmt.format(names[0][1]) + '?' if names else None


def dict_set_path(dic, path, value):
    """Set the value of a dict using path syntax.
    """
    cur = dic
    path = path.split(".")
    for key in path[:-1]:
        cur = cur.setdefault(key, {})
    cur[path[-1]] = value

def dict_del_path(dic, path):
    """Delete a value from a dict using path syntax.
    """
    if isinstance(path, str):
        path = path.split(".")
    if len(path) == 1:
        del[dic[path[0]]]
    else:
        pat, *rest = path
        dict_del_path(dic[pat], rest)
        if not dic[pat]:
            del dic[pat]

def dict_has_path(dic, path):
    """Check if a dict contains a value using path syntax.
    """
    cur = dic
    for pat in path.split("."):
        if isinstance(cur, dict) and pat in cur:
            cur = cur[pat]
        else:
            return False
    return True

_DEFAULT = object()
def dict_get_path(dic, path, default=_DEFAULT):
    """Get the value of a dict using path syntax.
    """
    cur = dic
    for pat in path.split("."):
        if isinstance(cur, dict) and pat in cur:
            cur = cur[pat]
        elif default != _DEFAULT:
            return default
        else:
            raise KeyError(path)
    return cur

def dict_merge(dict1, dict2):
    """Merge two dicts.
    """
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        return dict2
    for k in dict2:
        if k in dict1:
            dict1[k] = dict_merge(dict1[k], dict2[k])
        else:
            dict1[k] = dict2[k]
    return dict1

def dict_paths(dic, path=None):
    """Get paths in a dict.
    """
    res = []
    if path:
        if not dict_has_path(dic, path):
            return res
        value = dict_get_path(dic, path)
    else:
        value = dic
    if not isinstance(dic, dict):
        return res
    def _collect_path(dic, path):
        for k, val in dic.items():
            npath = f"{path}.{k}" if path is not None else k
            if isinstance(val, dict):
                _collect_path(val, npath)
            else:
                res.append(npath)
    _collect_path(value, path)
    return res


def parse_trained_models(argv):
    """Parse @syntax for specifying trained models on the command line.
    """
    names = []
    for part in argv:
        if re.match("^@[a-zA-Z0-9_-]+$", part):
            names.append(part[1:])
        else:
            break
    rest = argv[len(names):]
    return names, rest

def parse_split(value):
    """Decodes the split value.

    Returns a tuple (type, value) where type is either perc, num or dir set.
    """
    assert isinstance(value, (int, str))

    if isinstance(value, int):
        return ('num', value)
    if value.endswith("%"):
        return ('perc', float(value.rstrip("%").strip()))
    if value.isdigit():
        return ('num', int(value))
    return ('dir', value)

def format_info_text(text, indent=0, width=70):
    """Return text formatted for readability.
    """
    text = text.strip("\n")
    res = []
    for line in text.splitlines():
        if line.startswith("  "):
            res.append(line)
        elif line.strip() == "":
            res.append(line)
        else:
            res.extend(textwrap.wrap(line, width=width-indent))
    if indent:
        indstr = str(' ' * indent)
        res = list(map(lambda l: indstr + l, res))
    return "\n".join(res)

if os.name == 'nt':
    def xlink(src, dst):
        """Cross platform file links.
        """
        os.link(src, dst)
else:
    def xlink(src, dst):
        """Cross platform file links.
        """
        # use symlink on Unix
        os.symlink(src, dst)