"""List command.
"""
from copy import deepcopy
from collections import OrderedDict
import os
import os.path
import json
import csv
from datetime import datetime
import io
import sys

import yaml

from vergeml.command import command, CommandPlugin, Command
from vergeml.option import option
from vergeml.utils import VergeMLError
from vergeml.display import DISPLAY
from vergeml.config import parse_command

# Usage examples passed to the @command decorator of ListCommand below.
# Comparison filters take the form "<column> <operator> <value>"; they are
# split off from the other arguments by _parse_args.
EXAMPLES = """
$ ml list -sacc
# sort by acc value

$ ml list status -eq RUNNING
# show trainings that are currently running

$ ml list test_acc -gt 0.8
# show AIs with a test accuracy that is greater than 0.8

# available comparison operations:
# -gt, -lt, -eq, -neq, -gte and -lte
""".strip()

@command('list', descr="List trained models.", free_form=True, examples=EXAMPLES) # pylint: disable=R0903
@option('sort', descr="By which column to sort.", default='created-at', short='s')
@option('order', descr="Sort order.", default='asc', short='o', validate=('asc', 'desc'))
@option('columns', descr="Which columns to show.", type='Optional[Union[str, List[str]]]', short='c')
@option('output', descr="Output format.", default='table', validate=('table', 'csv', 'json'))
class ListCommand(CommandPlugin):
    """List trained models"""

    def __call__(self, args, env):
        """Filter, sort and print the trained models.

        :param args: raw command arguments, handed to ``_parse_args``.
        :param env: environment; ``trainings-dir`` locates the trained models.
        """
        # Options first: invalid arguments are reported even when nothing
        # has been trained yet.
        opts, comparisons = _parse_args(args, env)

        if os.path.exists(env.get('trainings-dir')):
            models, hyperparams = _find_trained_models(opts, env)
            header, rows, left_align = _format_table(opts, comparisons, models, hyperparams)
            _output_table(opts['output'], header, rows, left_align)
        else:
            # No trainings directory means there is nothing to list.
            print("No trainings found.", file=sys.stderr)

def _parse_args(args, env):
    """Parse ``ml list`` arguments and split off the comparison filters.

    Every ``<column> <op> <value>`` triple, where ``<op>`` is one of ``-gt``,
    ``-lt``, ``-eq``, ``-neq``, ``-gte`` or ``-lte``, is removed from the
    free-form argument list ``args[1]`` and returned separately. The remaining
    arguments are parsed as regular ``list`` options. Options not given on the
    command line are taken from the config file, then from option defaults.

    :param args: raw command arguments; only ``args[1]`` is read, and it is
        not modified.
    :param env: environment; the command's config section is read from it.
    :return: ``(args, cargs)`` - the option dict and a list of
        ``[column, op, value]`` lists.
    :raises VergeMLError: when an operator lacks an operand on either side, or
        when two comparisons share an operand (e.g. ``a -gt 1 -lt 2``).
    """
    # Work on a copy so that deleting the comparison triples below does not
    # mutate the caller's argument list.
    rest = list(args[1])

    comps = []
    for idx, arg in enumerate(rest):
        if arg not in ('-gt', '-lt', '-eq', '-neq', '-gte', '-lte'):
            continue
        start, end = idx - 1, idx + 1
        if start < 0 or end >= len(rest):
            raise VergeMLError("Invalid options.", help_topic='list')
        # Overlapping triples would be sliced into malformed comparisons.
        if comps and start <= comps[-1][1]:
            raise VergeMLError("Invalid options.", help_topic='list')
        comps.append((start, end))

    # Delete from the back so earlier (start, end) indices stay valid.
    cargs = []
    for start, end in reversed(comps):
        cargs.append(rest[start:end + 1])
        del rest[start:end + 1]

    cmd = deepcopy(Command.discover(ListCommand))
    cmd.free_form = False
    parsed = cmd.parse(['list'] + rest)

    # If existent, read settings from the config file.
    config = parse_command(cmd, env.get(cmd.name))

    # Set missing args from the config file.
    for k, arg in config.items():
        parsed.setdefault(k, arg)

    # Set missing args from default.
    for opt in cmd.options:
        if opt.name not in parsed and (opt.default is not None or not opt.is_required()):
            parsed[opt.name] = opt.default

    return parsed, cargs

def _find_trained_models(args, env):
    info = {}
    hyper = {}
    train_dir = env.get('trainings-dir')

    for trained_model in os.listdir(train_dir):
        data_yaml = os.path.join(train_dir, trained_model, 'data.yaml')
        if os.path.isfile(data_yaml):
            with open(data_yaml) as file:
                doc = yaml.safe_load(file)
        else:
            doc = {}
        info[trained_model] = {}
        hyper[trained_model] = {}

        if 'model' in doc:
            info[trained_model]['model'] = doc['model']

        if 'results' in doc:
            info[trained_model].update(doc['results'])

        if 'hyperparameters' in doc:
            hyper[trained_model].update(doc['hyperparameters'])

    sort = [s.strip() for s in args['sort'].split(",")]

    info = OrderedDict(sorted(info.items(), reverse=(args['order'] == 'asc'),
                              key=lambda x: [x[1].get(s, 0) for s in sort]))
    return info, hyper

def _format_table(args, cargs, info, hyper): # pylint: disable=R0912
    theader = ['AI', 'model', 'status', 'num-samples', 'training-start', 'epochs']
    exclude = ['training-end', 'steps', 'created-at']

    if args['columns']:
        cols = args['columns']
        if isinstance(cols, str):
            cols = cols.split(",")

        theader = ['AI'] + [s.strip() for s in cols]
        exclude = []

    tdata = []
    left_align = set([0])

    for trained_model, results in info.items():
        rdata = [""] * len(theader)
        rdata[0] = "@" + trained_model

        if not _filter(results, hyper[trained_model], cargs):
            continue

        for k, val in sorted(results.items()):
            if k in exclude and not args['columns']:
                continue

            if not k in theader and not args['columns'] and isinstance(val, (str, int, float)):
                theader.append(k)
                rdata.append(None)

            if k in theader:
                pos = theader.index(k)

                if k in ('training-start', 'training-end', 'created-at'):
                    val = datetime.utcfromtimestamp(val)
                    val = val.strftime("%Y-%m-%d %H:%M")
                elif isinstance(val, float):
                    val = "%.4f" % val
                elif isinstance(val, str):
                    left_align.add(pos)

                rdata[pos] = val

        for k, val in sorted(hyper[trained_model].items()):

            if k in theader:
                pos = theader.index(k)
                if isinstance(val, float):
                    val = "%.4f" % val
                elif isinstance(val, str):
                    left_align.add(pos)

                rdata[pos] = val

        tdata.append(rdata)

    return theader, tdata, left_align

def _output_table(output, theader, tdata, left_align):

    if not tdata:
        print("No matching trained models found.", file=sys.stderr)

    if output == 'table':
        if not tdata:
            return
        tdata.insert(0, theader)
        print(DISPLAY.table(tdata, left_align=left_align).getvalue(fit=True))

    elif output == 'json':
        res = []
        for row in tdata:
            res.append(dict(zip(theader, row)))
        print(json.dumps(res))

    elif output == 'csv':
        buffer = io.StringIO()

        writer = csv.writer(buffer)
        writer.writerow(theader)
        for row in tdata:
            writer.writerow(row)
        val = buffer.getvalue()
        val = val.replace('\r', '')
        print(val.strip())

def _filter(info, hyper, comp_args):
    try:
        cols = {}
        cols.update(hyper)
        cols.update(info)
        res = True
        for col, opr, val in comp_args:

            if not col in cols:
                return False

            cval = cols[col]

            if isinstance(cval, int):
                val = int(val)
            elif isinstance(cval, float):
                val = float(val)

            if opr == '-eq':
                res = res and (cval == val)
            elif opr == '-neq':
                res = res and (cval != val)
            elif opr == '-gt':
                res = res and (cval > val)
            elif opr == '-lt':
                res = res and (cval < val)
            elif opr == '-gte':
                res = res and (cval >= val)
            elif opr == '-lte':
                res = res and (cval <= val)

            if not res:
                return False
        return res
    except: # pylint: disable=W0702
        return False