#! /usr/bin/env python3

from typing import *

# Generic type variable used for the cell type in this module's annotations.
T = TypeVar('T')

import csv
from io import TextIOBase
import itertools
from itertools import combinations, permutations
import regex as re
import collections

from regexorder import RegexOrder
from toolz.itertoolz import isdistinct
from toolz.utils import no_default

from .seqtools import iter2seq, mergeseqs
from .settools import dropsupersets
from .__join import join as seqjoin

# A table is any iterable of rows; each row is a list or a tuple of cells.
Table = Iterable[Union[List[T], Tuple[T, ...]]]

def transpose(data: Table) -> Table:
    """Swap the rows and columns of a table.

    Lazily yields one tuple per column of ``data``. The rows are combined
    with ``zip``, so the result is truncated to the length of the shortest
    row.
    """
    columns = zip(*data)
    yield from columns


def loadcsv(path: Union[Iterable[str], str, TextIOBase], delimiter: str = ',') -> Table:
    """Lazily read CSV records from a file path or from a source of lines.

    Args:
        path: Either something to open (a ``str`` path, or any non-iterable
            object accepted by ``open`` such as a ``pathlib.Path``), or an
            iterable of text lines such as an open text file or a list of
            strings.
        delimiter: Field separator handed to ``csv.reader``.

    Yields:
        One list of string fields per CSV record.
    """
    # Imported locally: the ABC lives in ``collections.abc``; the old
    # ``collections.Iterable`` alias was removed in Python 3.10.
    from collections.abc import Iterable as IterableABC

    # A ``str`` is itself iterable, so it must be recognised as a path first;
    # otherwise csv.reader would iterate over its characters.
    if isinstance(path, str) or not isinstance(path, IterableABC):
        # The file is opened here, so it is also closed here: either when the
        # rows are exhausted or when the generator is closed early.
        with open(path, 'r', newline='') as f:
            yield from csv.reader(f, delimiter=delimiter)
    else:
        # Caller-owned line source: read from it but leave it open.
        yield from csv.reader(path, delimiter=delimiter)


def dumpcsv(path: Union[str, TextIOBase], data: Table, delimiter: str = ',') -> None:
    """Write the rows of a table as CSV.

    Args:
        path: Either a writable file-like object (anything with a ``write``
            method), which is written to and left open, or a path that is
            opened for writing and closed again before returning.
        data: Rows to write; each row is passed to the csv writer as is.
        delimiter: Field separator handed to ``csv.writer``.
    """
    if hasattr(path, 'write'):
        # Caller-owned stream: closing it is the caller's responsibility.
        csv.writer(path, delimiter=delimiter).writerows(data)
        return

    # We opened the file, so make sure it is flushed and closed even when
    # writing a row fails.
    with open(path, 'w', newline='') as f:
        csv.writer(f, delimiter=delimiter).writerows(data)


def mergecols(cols: Table, default=None, blank=None) -> Optional[Iterable[T]]:
    """Combine columns by delegating to ``mergeseqs``.

    ``mergeseqs`` receives a ``key`` predicate that accepts a cell when it is
    not ``None`` and its string form is non-empty after stripping the
    ``blank`` characters (whitespace when ``blank`` is ``None``). ``default``
    is forwarded unchanged.
    """
    def isfilled(cell):
        if cell is None:
            return False
        return str(cell).strip(blank) != ""

    return mergeseqs(cols, default=default, key=isfilled)


def sortedbycol(data: Table, key: Callable[[Iterable[T]], Any]) -> Table:
    """Reorder the columns of a table by sorting them with ``key``.

    ``key`` is called with each column (a tuple of cells). The sorted columns
    are transposed back, so the result is again an iterable of rows.
    """
    columns = transpose(data)
    ordered = sorted(columns, key=key)
    return transpose(ordered)


def filterbycol(data: Table, key: Callable[[Iterable[T]], Any]) -> Table:
    """Keep only the columns of a table for which ``key`` is truthy.

    ``key`` is called with each column (a tuple of cells). The surviving
    columns are transposed back, so the result is again an iterable of rows.
    """
    columns = transpose(data)
    kept = filter(key, columns)
    return transpose(kept)


def trim(table: Table, blank=None) -> Table:
    """Drop the blank rows and blank columns of a table.

    A cell is blank when it is ``None`` or when its string form is empty
    after stripping the ``blank`` characters (whitespace when ``blank`` is
    ``None``). Rows made only of blank cells are skipped, and from the
    remaining rows only the cells of columns holding at least one non-blank
    cell are kept. Each yielded row is a new list.
    """
    def hascontent(cell):
        return cell is not None and str(cell).strip(blank) != ""

    # Materialised first because the rows are walked twice below.
    rows = iter2seq(table)

    # One flag per column: True when any cell in that column has content.
    keepcol = [any(map(hascontent, col)) for col in transpose(rows)]

    for row in rows:
        if any(map(hascontent, row)):
            yield list(itertools.compress(row, keepcol))


def parse(lines: Iterable[str], sep=None, useregex=False) -> Table:
    """Split every line into a row of fields.

    With ``useregex`` false each line is split with ``str.split(sep)``, so
    ``sep=None`` splits on runs of whitespace. With ``useregex`` true ``sep``
    is a pattern, given either as a string to compile or as an already
    compiled object, and each line is split with the pattern's ``split``.
    """
    if not useregex:
        def splitline(line):
            return line.split(sep)
    else:
        pattern = sep if not isinstance(sep, str) else re.compile(sep)

        def splitline(line):
            return pattern.split(line)

    yield from map(splitline, lines)


def parsebymarkdown(text: str) -> Table:
    """Parse a pipe-delimited Markdown table into rows of stripped cells.

    The text is split into lines, empty lines are dropped, and each line is
    split on every ``|`` that is not preceded by a backslash. The result goes
    through ``trim`` with spaces, tabs, ``-`` and ``:`` counted as blank, and
    every remaining cell is stripped of surrounding whitespace.

    Args:
        text: The Markdown source, with LF or CRLF line endings.

    Yields:
        One list of cell strings per remaining row.
    """
    # Drop the "\r" of CRLF line endings. Left in place it makes the last
    # cell of every line non-blank, so the "---|---" separator row and blank
    # lines would leak into the output.
    lines = (line.rstrip('\r') for line in text.split('\n'))

    cells = parse(
        filter(None, lines),
        sep=r"(?<!\\)\|",
        useregex=True
    )

    for row in trim(cells, blank=" \t-:"):
        yield [cell.strip() for cell in row]


def parsebyregex(lines: Iterable[str], regex: Any) -> Table:
    """Turn each line into a row made of the capture groups of ``regex``.

    ``regex`` may be a pattern string (compiled here) or an already compiled
    pattern. Every line must match the pattern in full; groups that did not
    take part in the match are reported as ``""``.
    """
    pattern = regex if not isinstance(regex, str) else re.compile(regex)

    yield from (
        pattern.fullmatch(line).groups(default="")
        for line in lines
    )


def parsebyregexes(lines: Iterable[str], regexes: Any) -> Table:
    """Turn each line into a row by applying several patterns in sequence.

    Each entry of ``regexes`` may be a pattern string (compiled here) or an
    already compiled pattern. For every line the patterns are searched one
    after another, each search starting where the previous match ended, and
    the full text of each match becomes one cell of the row.
    """
    patterns = [
        item if not isinstance(item, str) else re.compile(item)
        for item in regexes
    ]

    for line in lines:
        row = []
        pos = 0

        for pattern in patterns:
            hit = pattern.search(line, pos)
            row.append(hit.group(0))
            # The next pattern continues right after this match.
            pos = hit.end()

        yield row


def inferschema(data: Table) -> Tuple[str, ...]:
    """Return one name per column of a table.

    Each column is passed to ``RegexOrder.matchall`` and the ``name``
    attribute of the result is collected, in column order.
    """
    matcher = RegexOrder()
    names = [matcher.matchall(col).name for col in transpose(data)]

    return tuple(names)


def hasheader(data: Table) -> float:
    """Score how much the first row of a table looks like a header.

    For each column, all cells except the first are classified with
    ``RegexOrder.matchall``. The first cell is then matched with
    ``RegexOrder.match`` against that classification. A column counts as
    header-like when the two results differ.

    Args:
        data: The table to inspect.

    Returns:
        The fraction of header-like columns, between 0 and 1. A table
        without any columns scores ``0.0``.
    """
    r = RegexOrder()

    cols = list(transpose(data))

    # No columns means no evidence of a header, and nothing to divide by.
    if not cols:
        return 0.0

    total = 0

    for col in cols:
        bodytype = r.matchall(col[1:])
        headtype = r.match(col[0], bodytype)

        if bodytype != headtype:
            total += 1

    return total / len(cols)


def candidatekeys(data: Table, maxcols: int = 1) -> Iterable[Tuple[int, ...]]:
    """Find combinations of columns whose values are distinct on every row.

    Args:
        data: The table to inspect.
        maxcols: Largest number of columns a key may consist of.

    Returns:
        A lazy iterable of keys. Each key is a tuple of column indices in
        ascending order. Combinations are filtered through ``dropsupersets``
        (presumably discarding keys that contain a smaller key; confirm
        against that helper).
    """
    cols = list(transpose(data))
    ncols = len(cols)

    uniquesets = (
        set(colids)
        for size in range(1, maxcols + 1)
        for colids in combinations(range(ncols), size)
        if isdistinct(transpose(cols[j] for j in colids))
    )

    def askey(colids):
        # Sets have no defined iteration order (tuple({1, 8}) may come out
        # as (8, 1)), so sort to give every key a stable column order.
        return tuple(sorted(colids))

    return map(askey, dropsupersets(uniquesets))


def foreignkeys(primarydata: Table, primarykey: Tuple[int, ...], foreigndata: Table) -> Iterable[Tuple[int, ...]]:
    """Find column tuples of a foreign table that could reference a primary key.

    The key values of ``primarydata`` are collected from the ``primarykey``
    columns. Every ordered selection of as many ``foreigndata`` columns is
    then tested, and it is yielded when all of its value tuples occur among
    the primary key values.
    """
    keyvalues = {
        tuple(row[j] for j in primarykey)
        for row in primarydata
    }

    foreigncols = list(transpose(foreigndata))
    candidates = permutations(range(len(foreigncols)), len(primarykey))

    def covered(colids):
        return set(transpose(foreigncols[j] for j in colids)) <= keyvalues

    return (colids for colids in candidates if covered(colids))


def join(
        lefttable, righttable,
        leftkey, rightkey,
        leftjoin=False, rightjoin=False
    ):
    """Join two tables on key columns by delegating to ``seqjoin``.

    Args:
        lefttable: Rows of the left table.
        righttable: Rows of the right table.
        leftkey: Column indices that form the join key of a left row.
        rightkey: Column indices that form the join key of a right row.
        leftjoin: When true, a row of ``None`` as wide as the first right row
            is passed to ``seqjoin`` as ``rightdefault`` (presumably used to
            pair left rows that have no match; confirm against ``seqjoin``).
        rightjoin: When true, a row of ``None`` as wide as the first left row
            is passed to ``seqjoin`` as ``leftdefault``.

    Returns:
        Whatever ``seqjoin`` returns for the two tables and key functions.
    """
    lefttable = iter2seq(lefttable)
    righttable = iter2seq(righttable)

    def blankrow(table):
        # An empty table has no first row to measure. Use a zero-width
        # filler instead of failing with IndexError on table[0].
        width = len(table[0]) if len(table) > 0 else 0
        return [None] * width

    def keyof(colids):
        return lambda row: tuple(row[i] for i in colids)

    return seqjoin(
        lefttable, righttable,
        keyof(leftkey), keyof(rightkey),
        leftdefault=(blankrow(lefttable) if rightjoin else no_default),
        rightdefault=(blankrow(righttable) if leftjoin else no_default)
    )