import argparse
import codecs
import collections.abc
import datetime
import functools
import io
import json
import logging
import os
import re
import sys
import threading
import traceback
from typing import Callable, Dict, Sequence, Tuple, Union

import numpy
import xxhash
import yaml

from modelforge.meta import get_datetime_now

# Whether the most recent setup() call enabled structured (JSON to stdout) logging.
logs_are_structured = False


def get_timezone() -> Tuple[datetime.tzinfo, str]:
    """
    Discover the current time zone and its standard string representation (for source{d}).

    :return: Pair of the local tzinfo and its UTC offset string, where a colon is \
             inserted before the last two digits of strftime's "%z" (e.g. "+0200" -> "+02:00").
    """
    local_now = get_datetime_now().astimezone()
    raw_offset = local_now.strftime("%z")
    formatted = "%s:%s" % (raw_offset[:-2], raw_offset[-2:])
    return local_now.tzinfo, formatted


# Local time zone and its offset string, computed once at import time.
timezone, tzstr = get_timezone()

# Seasonal prefix for colored console output: Santa for the whole of December,
# a pumpkin during the last week of October (days 25-31), nothing otherwise.
_now = get_datetime_now()
_fest = ""
if _now.month == 12:
    _fest = "🎅"
elif _now.month == 10 and _now.day >= 25:
    _fest = "🎃"
del _now


def format_datetime(dt: datetime.datetime):
    """
    Represent the date and time in source{d} format.

    Example: ``2019-01-02T03:04:05.000006000+02:00``. The fractional part is the \
    microseconds padded with three zeros to nanosecond precision, and the suffix is \
    the module-level `tzstr` offset (it is not derived from `dt` itself).

    :param dt: Date and time to format.
    :return: The formatted timestamp string.
    """
    # %H is the portable, zero-padded hour. The previous glibc-only %k pads with a
    # space ("T 3:04") and is rejected by some platforms' strftime.
    return dt.strftime("%Y-%m-%dT%H:%M:%S.%f000") + tzstr


def reduce_thread_id(thread_id: int) -> str:
    """
    Make a shorter thread identifier by hashing the original.

    The id is serialized as 8 little-endian bytes and hashed with xxh32; the first \
    four hex characters of the digest are returned.
    """
    packed = thread_id.to_bytes(8, byteorder="little")
    digest = xxhash.xxh32(packed).hexdigest()
    return digest[0:4]


def with_logger(cls):
    """
    Class decorator: attach a logger named after the class as the ``_log`` attribute.

    :param cls: Class to decorate; it is modified in place.
    :return: The same class object.
    """
    class_logger = logging.getLogger(cls.__name__)
    setattr(cls, "_log", class_logger)
    return cls


# Logger names whose messages are exempt from the trailing-dot check.
trailing_dot_exceptions = set()


def check_trailing_dot(func: Callable) -> Callable:
    """
    Decorate a function to check if the log message ends with a dot.

    AssertionError is raised if so. Messages ending with ".." (e.g. an ellipsis), \
    non-string messages and loggers listed in `trailing_dot_exceptions` are exempt.

    Two kinds of callables are supported:
    - plain functions and bound methods, called as ``func(record)``;
    - methods decorated inside a class body, called as ``func(self, record)``.
    In both cases the log record is the last positional argument (or the \
    ``record=`` keyword argument).

    :param func: Callable which receives a `logging.LogRecord`.
    :return: The wrapped callable.
    """
    @functools.wraps(func)
    def decorated_with_check_trailing_dot(*args, **kwargs):
        # Taking the record from the end makes the decorator usable on unbound methods,
        # where `self` comes first.
        record = kwargs["record"] if "record" in kwargs else args[-1]
        if record.name not in trailing_dot_exceptions:
            msg = record.msg
            if isinstance(msg, str) and msg.endswith(".") and not msg.endswith(".."):
                raise AssertionError(
                    "Log message is not allowed to have a trailing dot: %s: \"%s\"" %
                    (record.name, msg))
        return func(*args, **kwargs)
    return decorated_with_check_trailing_dot


class NumpyLogRecord(logging.LogRecord):
    """
    LogRecord with the special handling of numpy arrays which shortens the long ones.
    """

    @staticmethod
    def array2string(arr: numpy.ndarray) -> str:
        """
        Format numpy array as a string followed by its dtype and shape, e.g. ``[0 1 2]int64[3]``.

        Arrays with more than 11 elements are summarized by numpy (the middle is elided).
        """
        shape = str(arr.shape)[1:-1]
        if shape.endswith(","):
            # 1-D shapes print as "(n,)" - drop the tuple's trailing comma.
            shape = shape[:-1]
        return numpy.array2string(arr, threshold=11) + "%s[%s]" % (arr.dtype, shape)

    def getMessage(self) -> str:
        """
        Return the message for this LogRecord.

        Return the message for this LogRecord after merging any user-supplied \
        arguments with the message. numpy arrays, both as the message itself and \
        as arguments, are rendered with `array2string`.

        :raises TypeError: if `args` is neither a mapping nor a sequence.
        """
        if isinstance(self.msg, numpy.ndarray):
            msg = self.array2string(self.msg)
        else:
            msg = str(self.msg)
        if not self.args:
            return msg
        a2s = self.array2string
        if isinstance(self.args, collections.abc.Mapping):
            # Any mapping, not only dict: logging.LogRecord unwraps a single non-empty
            # Mapping argument into `args` (e.g. types.MappingProxyType).
            args = {k: (a2s(v) if isinstance(v, numpy.ndarray) else v)
                    for (k, v) in self.args.items()}
        elif isinstance(self.args, collections.abc.Sequence):
            args = tuple((a2s(a) if isinstance(a, numpy.ndarray) else a)
                         for a in self.args)
        else:
            raise TypeError("Unexpected input '%s' with type '%s'" % (self.args,
                                                                      type(self.args)))
        return msg % args


class AwesomeFormatter(logging.Formatter):
    """
    logging.Formatter which adds colors to messages and shortens thread ids.

    INFO messages which contain one of `GREEN_MARKERS` (case-insensitive) are \
    highlighted in green; DEBUG lines are dimmed and use `logging.BASIC_FORMAT`.
    """

    GREEN_MARKERS = [" ok", "ok:", "finished", "complete", "ready",
                     "done", "running", "success", "saved"]
    GREEN_RE = re.compile("|".join(GREEN_MARKERS))

    def formatMessage(self, record: logging.LogRecord) -> str:
        """
        Convert the already filled log record to a string.

        Requires `record.message` to be set already; stores the shortened thread id \
        in `record.rthread` as a side effect.
        """
        severity = record.levelno
        if severity <= logging.DEBUG:
            template = "\033[0;37m" + logging.BASIC_FORMAT + "\033[0m"
        else:
            message_color = "0"
            if severity <= logging.INFO:
                badge_color = "1;36"
                if self.GREEN_RE.search(record.message.lower()) is not None:
                    message_color = "1;32"
            elif severity <= logging.WARNING:
                badge_color = "1;33"
            elif severity <= logging.CRITICAL:
                badge_color = "1;31"
            else:
                # Custom levels above CRITICAL are left uncolored.
                badge_color = "0"
            template = ("\033[{0}m%(levelname)s\033[0m:%(rthread)s:%(name)s:"
                        "\033[{1}m%(message)s\033[0m").format(badge_color, message_color)
        record.rthread = reduce_thread_id(record.thread)
        return (_fest + template) % record.__dict__


class StructuredHandler(logging.Handler):
    """logging handler for structured logging: one JSON object per line on stdout."""

    def __init__(self, level=logging.NOTSET):
        """Initialize a new StructuredHandler."""
        super().__init__(level)
        # Thread-local storage; set_context() assigns `context` on it for the calling thread.
        self.local = threading.local()

    @check_trailing_dot
    def emit(self, record: logging.LogRecord):
        """
        Print the log record formatted as JSON to stdout.

        The object has the keys "level", "msg", "source" ("file:line"), "time" \
        (record creation time in the module-level `timezone`) and "thread" (shortened \
        thread id). It additionally has "error" (the formatted exception lines) when the \
        record carries exc_info, and "context" when one was assigned to the current thread.
        """
        created = datetime.datetime.fromtimestamp(record.created, timezone)
        obj = {
            "level": record.levelname.lower(),
            # getMessage() only %-formats when there are args (so "100% done" without
            # args survives), converts non-str messages with str() and lets
            # NumpyLogRecord render arrays.
            "msg": record.getMessage(),
            "source": "%s:%d" % (record.filename, record.lineno),
            "time": format_datetime(created),
            "thread": reduce_thread_id(record.thread),
        }
        if record.exc_info is not None:
            lines = traceback.format_exception(*record.exc_info)
            # Drop the "Traceback (most recent call last):" banner only if it is present:
            # an exception without a traceback yields just the "Type: message" line.
            if lines and lines[0].startswith("Traceback"):
                lines = lines[1:]
            obj["error"] = lines
        try:
            obj["context"] = self.local.context
        except AttributeError:
            # No context has been assigned for the current thread.
            pass
        json.dump(obj, sys.stdout, sort_keys=True)
        sys.stdout.write("\n")
        sys.stdout.flush()

    def flush(self):
        """Write all pending text to stdout."""
        sys.stdout.flush()


def setup(level: Union[str, int], structured: bool, config_path: str = None):
    """
    Make stdout and stderr unicode friendly in case of misconfigured \
    environments, initializes the logging, structured logging and \
    enables colored logs if it is appropriate.

    :param level: The global logging level.
    :param structured: Output JSON logs to stdout.
    :param config_path: Path to a yaml file that configures the level of output of the loggers. \
                        Root logger level is set through the level argument and will override any \
                        root configuration found in the conf file. Unknown level names fall back \
                        to `level`; an empty file is ignored.
    :return: None
    """
    global logs_are_structured
    logs_are_structured = structured

    if not isinstance(level, int):
        level = logging._nameToLevel[level]

    def ensure_utf8_stream(stream):
        # Streams exposing a binary `buffer` are re-wrapped with a UTF-8 writer;
        # in-memory StringIO streams are left untouched.
        if not isinstance(stream, io.StringIO) and hasattr(stream, "buffer"):
            stream = codecs.getwriter("utf-8")(stream.buffer)
            stream.encoding = "utf-8"
        return stream

    sys.stdout, sys.stderr = (ensure_utf8_stream(s)
                              for s in (sys.stdout, sys.stderr))

    # basicConfig is only called to make sure there is at least one handler for the root logger.
    # All the output level setting is done right afterwards.
    logging.basicConfig()
    logging.setLogRecordFactory(NumpyLogRecord)
    if config_path is not None and os.path.isfile(config_path):
        with open(config_path) as fh:
            # An empty YAML document loads as None: treat it as "no overrides".
            config = yaml.safe_load(fh) or {}
        for key, val in config.items():
            logging.getLogger(key).setLevel(logging._nameToLevel.get(val, level))
    root = logging.getLogger()
    root.setLevel(level)

    if not structured:
        handler = root.handlers[0]
        handler.emit = check_trailing_dot(handler.emit)
        # Colors only when attached to an interactive terminal.
        if not sys.stdin.closed and sys.stdout.isatty():
            handler.setFormatter(AwesomeFormatter())
    else:
        root.handlers[0] = StructuredHandler(level)


def set_context(context):
    """
    Assign the logging context - an abstract object - to the current thread.

    This is a no-op unless the first handler of the root logger is a StructuredHandler. \
    In that case the context is stored in the handler's thread-local storage while the \
    handler lock is held.
    """
    root_handlers = logging.getLogger().handlers
    if not root_handlers:
        # logging is not initialized
        return
    target = root_handlers[0]
    if isinstance(target, StructuredHandler):
        target.acquire()
        try:
            target.local.context = context
        finally:
            target.release()


def add_logging_args(parser: argparse.ArgumentParser, patch: bool = True,
                     erase_args: bool = True) -> None:
    """
    Add command line flags specific to logging.

    :param parser: `argparse` parser where to add new flags.
    :param patch: Patch parse_args() to automatically setup logging.
    :param erase_args: Automatically remove logging-related flags from parsed args \
                       (only takes effect when `patch` is enabled).
    """
    logging_dests = ("log_level", "log_structured", "log_config")
    parser.add_argument("--log-level", default="INFO", choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--log-structured", action="store_true",
                        help="Enable structured logging (JSON record per line).")
    parser.add_argument("--log-config",
                        help="Path to the file which sets individual log levels of domains.")
    if not patch or hasattr(parser, "_original_parse_args"):
        return

    # parse_args() is monkey-patched because custom argparse actions are not invoked
    # when the corresponding --flags are absent from the command line.
    def parse_and_setup_logging(args=None, namespace=None) -> argparse.Namespace:
        parsed = parser._original_parse_args(args, namespace)
        setup(parsed.log_level, parsed.log_structured, parsed.log_config)
        if erase_args:
            for dest in logging_dests:
                delattr(parsed, dest)
        return parsed

    parser._original_parse_args = parser.parse_args
    parser.parse_args = parse_and_setup_logging