#!/usr/bin/env python3

# Author: Jan-Thorsten Peter <peter@cs.rwth-aachen.de>
import collections
from sisyphus.hash import *
import inspect
import os
import platform
import sys
import shutil
import time
import logging
import subprocess
import linecache
from typing import Set, Any

try:
    import tracemalloc
except ModuleNotFoundError:
    tracemalloc = None

import sisyphus.global_settings as gs
from sisyphus.block import Block


def get_system_informations(file=sys.stdout):
    """Print basic host information (uname and load average) to *file*.

    :param file: writable text stream, defaults to stdout
    """
    for label, value in (("Uname:", platform.uname()),
                         ("Load:", os.getloadavg())):
        print(label, value, file=file)


def str_to_GB(m):
    """ Takes a string with size units and converts it into float as GB.
    If only a number is given it assumes it's gigabytes

    :param m: size, either a plain number or a number suffixed with T, G, M, or K
    :return: size in gigabytes
    :rtype: float
    """
    try:
        m = float(m)
    except ValueError:
        # Bug fix: the 'T' case used a separate `if` instead of `elif`, so a
        # terabyte value was converted to float and then `m[-1]` crashed with
        # "'float' object is not subscriptable" in the following `if`.
        if m[-1] == 'T':
            m = float(m[:-1]) * 1024
        elif m[-1] == 'G':
            m = float(m[:-1])
        elif m[-1] == 'M':
            m = float(m[:-1]) / 1024.
        elif m[-1] == 'K':
            m = float(m[:-1]) / 1024. / 1024.
        else:
            # unknown unit suffix; AssertionError kept for backward compatibility
            assert False, "Unknown size unit: %s" % m
    return m


def str_to_hours(t):
    """ Takes a string and converts it into hours.

    A plain number is interpreted as hours already; a ``HH:MM:SS`` string
    is converted into fractional hours.

    :param t: number of hours, or a 'HH:MM:SS' time string
    :return: time in hours
    :rtype: float
    """
    try:
        return float(t)
    except ValueError:
        fields = t.split(':')
        assert(len(fields) == 3)
        hours, minutes, seconds = (int(f) for f in fields)
        return (hours * 3600 + minutes * 60 + seconds) / 3600.0


def extract_paths(args: Any) -> Set:
    """
    Extract all path objects from the given arguments.

    Recursively walks containers and object state.  Dict keys and
    attribute names starting with ``'_sis_'`` are not followed;
    Block instances are ignored entirely.

    :param args: arbitrarily nested structure to search
    :rtype: set
    """
    out = set()
    # Blocks are skipped completely — they contribute no paths themselves.
    if isinstance(args, Block):
        return out
    # Objects flagged with `_sis_path is True` are the path objects we collect.
    if hasattr(args, '_sis_path') and args._sis_path is True:
        out = {args}
    elif isinstance(args, (list, tuple, set)):
        for a in args:
            out = out.union(extract_paths(a))
    elif isinstance(args, dict):
        for k, v in args.items():
            # skip internal '_sis_*' entries; non-str keys are always followed
            if not type(k) == str or not k.startswith('_sis_'):
                out = out.union(extract_paths(v))
    elif hasattr(args, '__sis_state__') and not inspect.isclass(args):
        # sisyphus-specific state hook takes precedence over pickle state
        out = out.union(extract_paths(args.__sis_state__()))
    elif hasattr(args, '__getstate__') and not inspect.isclass(args):
        # NOTE(review): on Python >= 3.11 every object defines __getstate__,
        # which would make the __dict__/__slots__ branches below unreachable
        # for most objects — verify against the supported Python versions.
        out = out.union(extract_paths(args.__getstate__()))
    elif hasattr(args, '__dict__'):
        for k, v in args.__dict__.items():
            if not type(k) == str or not k.startswith('_sis_'):
                out = out.union(extract_paths(v))
    elif hasattr(args, '__slots__'):
        for k in args.__slots__:
            # a slot may be declared but unset, hence the hasattr check
            if hasattr(args, k) and not k.startswith('_sis_'):
                a = getattr(args, k)
                out = out.union(extract_paths(a))
    return out


def sis_hash(obj):
    """
    Takes most objects and tries to convert the current state into a hash.

    Delegates to the hash function configured in the global settings
    (``gs.SIS_HASH``).

    :param object obj:
    :rtype: str
    """
    return gs.SIS_HASH(obj)


def try_get(v):
    """ Return ``v.get()`` if possible, otherwise return ``v`` unchanged.

    Useful to convert a sisyphus path or variable into the stored object.
    Any AttributeError (missing ``get`` or raised inside it) falls back
    to the original value.
    """
    try:
        result = v.get()
    except AttributeError:
        result = v
    return result


class execute_in_dir(object):
    """Context manager that runs the enclosed code in another directory.

    The working directory is switched on entry and restored on exit, e.g.::

        cwd = os.getcwd()
        with execute_in_dir('foo'):
            assert os.path.join(cwd, 'foo') == os.getcwd()
        assert cwd == os.getcwd()
    """

    def __init__(self, workdir):
        # target directory, resolved relative to the cwd at __enter__ time
        self.workdir = workdir

    def __enter__(self):
        # remember where we came from so __exit__ can restore it
        self.base_dir = os.getcwd()
        os.chdir(self.workdir)

    def __exit__(self, type, value, traceback):
        os.chdir(self.base_dir)


class cache_result(object):
    """Decorator that caches a function's result for a limited time.

    :param cache_time: seconds a cached value stays fresh
    :param force_update: name of a keyword argument that, when truthy,
        forces recomputation (the argument is stripped before the call)
    :param clear_cache: name of a keyword argument whose presence empties
        the whole cache and returns None without calling the function
    """

    def __init__(self, cache_time=30, force_update=None, clear_cache=None):
        self.cache = {}
        self.time = collections.defaultdict(int)
        self.cache_time = cache_time
        self.force_update = force_update
        self.clear_cache = clear_cache

    def __call__(self, f):
        def cache_f(*args, **kwargs):
            # a request to drop all cached values: reset and bail out
            if self.clear_cache and self.clear_cache in kwargs:
                self.cache = {}
                return

            refresh = False
            if self.force_update and kwargs.get(self.force_update, False):
                del kwargs[self.force_update]
                refresh = True

            # str() makes the key hashable; a missed hit only costs us the
            # caching effect, which shouldn't happen that often
            key = str((f, args, kwargs))

            stale = time.time() - self.time[key] > self.cache_time
            if (not refresh and stale) or key not in self.cache:
                refresh = True

            if not refresh:
                return self.cache[key]
            result = f(*args, **kwargs)
            self.cache[key] = result
            self.time[key] = time.time()
            return result
        return cache_f


def sh(command,
       capture_output=False,
       pipefail=True,
       executable=None,
       except_return_codes=(0,),
       sis_quiet=False,
       sis_replace=None,
       include_stderr=False,
       **kwargs):
    """ Calls an external shell, replacing ``{name}`` placeholders in the
    command (first from *sis_replace*, then from *kwargs*) before executing it.

    :param str command: shell command, may contain ``{name}`` placeholders
    :param bool capture_output: return the command's stdout as str instead of passing it through
    :param bool pipefail: make the command fail if any part of a pipe fails (default executable only)
    :param executable: shell binary, defaults to /bin/bash
    :param except_return_codes: return codes that are accepted instead of raising
    :param bool sis_quiet: suppress the info log line
    :param dict sis_replace: placeholder replacements (kwargs take precedence)
    :param bool include_stderr: merge stderr into the captured output
    :return: decoded stdout if capture_output, else None
    :raises subprocess.CalledProcessError: if the return code is not in except_return_codes
    """
    # Bug fix: default used to be a mutable dict literal shared across calls.
    if sis_replace is None:
        sis_replace = {}
    replace = {}
    replace.update(sis_replace)
    replace.update(kwargs)

    command = command.format(**replace)
    if capture_output:
        msg = "Run in Shell (capture output): %s"
    else:
        msg = "Run in Shell: %s"
    msg = msg % command
    if not sis_quiet:
        logging.info(msg)
    sys.stdout.flush()
    sys.stderr.flush()

    if executable is None:
        executable = '/bin/bash'
        if pipefail:
            # this ensures that the job will fail if any part inside of a pipe fails
            command = 'set -ueo pipefail && ' + command

    try:
        if capture_output:
            return subprocess.check_output(command, shell=True, executable=executable,
                                           stderr=subprocess.STDOUT if include_stderr else None).decode()
        else:
            subprocess.check_call(command, shell=True, executable=executable)
    except subprocess.CalledProcessError as e:
        if e.returncode not in except_return_codes:
            raise
        elif capture_output:
            # Bug fix: decode the captured bytes so accepted non-zero exits
            # return str, consistent with the success path above.
            return e.output.decode()


def hardlink_or_copy(src, dst, use_symlink_instead_of_copy=False):
    """ Emulate copying of directories by using hardlinks; fall back if linking fails.
    Recursively creates new directories and creates hardlinks of all source files
    into these directories. If hardlinking fails (cross-device link), the file is
    copied instead (or the whole source is symlinked if requested).

    :param src: source directory
    :param dst: destination directory (must exist)
    :param bool use_symlink_instead_of_copy: on hardlink failure, symlink src instead of copying files
    :return:
    """

    for dirpath, dirnames, filenames in os.walk(src):
        # get path relative to the given source directory
        relpath = dirpath[len(src) + 1:]

        # create directory if it doesn't exist
        try:
            os.mkdir(os.path.join(dst, relpath))
        except FileExistsError:
            assert os.path.isdir(os.path.join(dst, relpath))

        # create subdirectories
        for dirname in dirnames:
            try:
                os.mkdir(os.path.join(dst, relpath, dirname))
            except FileExistsError:
                # Bug fix: previously asserted on the parent path, not the
                # subdirectory that was just attempted.
                assert os.path.isdir(os.path.join(dst, relpath, dirname))

        # link or copy files
        for filename in filenames:
            src_file = os.path.join(dirpath, filename)
            dst_file = os.path.join(dst, relpath, filename)
            try:
                os.link(src_file, dst_file)
            except FileExistsError:
                assert os.path.isfile(dst_file)
            except OSError as e:
                # Bug fix: the errno check was inverted and the two fallback
                # branches were swapped (the symlink branch copied and vice
                # versa). 18 == errno.EXDEV: hardlink across filesystems.
                if e.errno != 18:
                    raise e
                if use_symlink_instead_of_copy:
                    logging.warning('Could not hardlink %s to %s, use symlink' % (src, dst))
                    os.symlink(os.path.abspath(src), dst)
                else:
                    logging.warning('Could not hardlink %s to %s, use copy' % (src, dst))
                    shutil.copy2(src_file, dst_file)


def default_handle_exception_interrupt_main_thread(func):
    """
    :param func: any function. usually run in another thread.
      If some exception occurs, it will interrupt the main thread (send KeyboardInterrupt to the main thread).
      If this is run in the main thread itself, it will raise SystemExit(1).
    :return: function func wrapped
    """
    import sys
    import _thread
    import threading

    def wrapped_func(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            current = threading.current_thread()
            logging.error("Exception in thread %r:" % current)
            # print the traceback the same way an uncaught exception would
            sys.excepthook(*sys.exc_info())
            if current is not threading.main_thread():
                # poke the main thread with a KeyboardInterrupt as well
                _thread.interrupt_main()
            raise SystemExit(1)

    return wrapped_func


def dump_all_thread_tracebacks(exclude_thread_ids=None, exclude_self=False, file=sys.stderr):
    """
    Print a stack trace for every (threading-managed) thread to *file*,
    skipping excluded threads and well-known idle multiprocessing threads.

    :param set[int]|None exclude_thread_ids: set|list of thread.ident to exclude
    :param bool exclude_self: also exclude the calling thread
    :param file: writable text stream, defaults to stderr
    """
    from traceback import print_stack, walk_stack
    from multiprocessing.pool import worker as mp_worker
    from multiprocessing.pool import Pool
    from queue import Queue
    import threading

    # sys._current_frames is CPython-specific; bail out gracefully elsewhere.
    if not hasattr(sys, "_current_frames"):
        print("Does not have sys._current_frames, cannot get thread tracebacks.", file=file)
        return
    if exclude_thread_ids is None:
        exclude_thread_ids = set()
    # copy so we never mutate the caller's collection
    exclude_thread_ids = set(exclude_thread_ids)
    if exclude_self:
        exclude_thread_ids.add(threading.current_thread().ident)

    print("", file=file)
    threads = {t.ident: t for t in threading.enumerate()}
    for tid, stack in sorted(sys._current_frames().items()):
        # This is a bug in earlier Python versions.
        # http://bugs.python.org/issue17094
        # Note that this leaves out all threads not created via the threading module.
        if tid not in threads:
            continue
        tags = []
        thread = threads.get(tid)
        if thread:
            assert isinstance(thread, threading.Thread)
            if thread is threading.current_thread():
                tags += ["current"]
            if thread is threading.main_thread():
                tags += ["main"]
            tags += [str(thread)]
        else:
            tags += ["unknown with id %i" % tid]
        print("Thread %s:" % ", ".join(tags), file=file)
        if tid in exclude_thread_ids:
            print("(Excluded thread.)\n", file=file)
            continue
        # walk_stack yields (frame, lineno) pairs from the thread's top frame
        stack_frames = [f[0] for f in walk_stack(stack)]
        stack_func_code = [f.f_code for f in stack_frames]
        # The following checks identify idle multiprocessing service threads
        # by comparing code objects on the stack; their traces are just noise.
        if mp_worker.__code__ in stack_func_code:
            i = stack_func_code.index(mp_worker.__code__)
            if i > 0 and stack_func_code[i - 1] is Queue.get.__code__:
                print("(Exclude multiprocessing idling worker.)\n", file=file)
                continue
        if Pool._handle_tasks.__code__ in stack_func_code:
            i = stack_func_code.index(Pool._handle_tasks.__code__)
            if i > 0 and stack_func_code[i - 1] is Queue.get.__code__:
                print("(Exclude multiprocessing idling task handler.)\n", file=file)
                continue
        if Pool._handle_workers.__code__ in stack_func_code:
            i = stack_func_code.index(Pool._handle_workers.__code__)
            if i == 0:  # time.sleep is native, thus not on the stack
                print("(Exclude multiprocessing idling worker handler.)\n", file=file)
                continue
        if Pool._handle_results.__code__ in stack_func_code:
            i = stack_func_code.index(Pool._handle_results.__code__)
            if i > 0 and stack_func_code[i - 1] is Queue.get.__code__:
                print("(Exclude multiprocessing idling result handler.)\n", file=file)
                continue
        print_stack(stack, file=file)
        print("", file=file)
    print("That were all threads.", file=file)


def format_signum(signum):
    """
    Render a signal number as ``"<number> (<SIGNAME>)"``.

    :param int signum: signal number, e.g. signal.SIGUSR1
    :return: string "signum (signame)"
    :rtype: str
    """
    import signal
    names = {}
    # Iterate names in descending order so that for aliased signal values
    # the alphabetically first SIG* name wins (later assignments overwrite).
    for name, value in reversed(sorted(signal.__dict__.items())):
        if name.startswith('SIG') and not name.startswith('SIG_'):
            names[value] = name
    return "%s (%s)" % (signum, names.get(signum, "unknown"))


def signal_handler(signum, frame):
    """
    Prints a message on stderr and dumps all thread stacks.

    :param int signum: e.g. signal.SIGUSR1
    :param frame: ignored, will dump all threads
    """
    message = "Signal handler: got signal %s" % format_signum(signum)
    print(message, file=sys.stderr)
    dump_all_thread_tracebacks(file=sys.stderr)


def install_signal_handler_if_default(signum, exceptions_are_fatal=False):
    """
    Register our traceback-dumping handler for *signum*, but only when no
    other handler has been installed yet (i.e. the signal is at SIG_DFL).

    :param int signum: e.g. signal.SIGUSR1
    :param bool exceptions_are_fatal: if True, will reraise any exceptions. if False, will just print a message
    :return: True iff no exception, False otherwise. not necessarily that we registered our own handler
    :rtype: bool
    """
    try:
        import signal
        current = signal.getsignal(signum)
        if current == signal.SIG_DFL:
            signal.signal(signum, signal_handler)
        return True
    except Exception as exc:
        if exceptions_are_fatal:
            raise
        print("Cannot install signal handler for signal %s, exception %s" % (format_signum(signum), exc))
        return False


def maybe_install_signal_handers():
    """Install the traceback-dump handler for SIGUSR1/SIGUSR2 when unclaimed."""
    import signal
    for signum in (signal.SIGUSR1, signal.SIGUSR2):
        install_signal_handler_if_default(signum)


class MemoryProfiler:
    """Writes periodic tracemalloc reports to a log stream.

    A report is only emitted when the total allocated size changed by at
    least *min_change* bytes since the last report.
    """

    def __init__(self, log_stream, line_limit=10, min_change=512000):
        self.log_stream = log_stream        # writable text stream for reports
        self.limit = line_limit             # how many top allocation sites to list
        tracemalloc.start()
        self.min_change = min_change        # minimum byte delta to trigger a report
        self.last_total = 0                 # total size at the last report

    def snapshot(self):
        """Take a tracemalloc snapshot and write a report if usage changed enough."""
        stats = tracemalloc.take_snapshot().statistics('lineno')
        total = sum(stat.size for stat in stats)
        if abs(self.last_total - total) < self.min_change:
            return

        self.last_total = total

        write = self.log_stream.write
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())
        write("Top %s lines at %s\n" % (self.limit, timestamp))
        for rank, stat in enumerate(stats[:self.limit], 1):
            frame = stat.traceback[0]
            # shorten "/path/to/module/file.py" to "module/file.py"
            short_name = os.sep.join(frame.filename.split(os.sep)[-2:])
            write("#%s: %s:%s: %.1f KiB\n" % (rank, short_name, frame.lineno, stat.size / 1024))
            source_line = linecache.getline(frame.filename, frame.lineno).strip()
            if source_line:
                write('    %s\n' % source_line)

        remaining = stats[self.limit:]
        if remaining:
            write("%s other: %.1f KiB\n" % (len(remaining), sum(stat.size for stat in remaining) / 1024))
        write("Total allocated size: %.1f KiB\n\n" % (total / 1024))
        self.log_stream.flush()


class EnvironmentModifier:
    """
    A class to cleanup the environment before a job starts
    """

    def __init__(self):
        self.keep_vars = set()  # names of environment variables to preserve
        self.set_vars = {}      # variables to (re)set explicitly

    def keep(self, var):
        """Mark a single variable name (str) or an iterable of names to be kept."""
        if type(var) == str:
            self.keep_vars.add(var)
        else:
            self.keep_vars.update(var)

    def set(self, var, value=None):
        """Set a single variable to *value*, or merge a dict of variables."""
        if type(var) == dict:
            self.set_vars.update(var)
        else:
            self.set_vars[var] = value

    def modify_environment(self):
        """Drop all non-kept variables from os.environ, then apply set_vars.

        String values may reference the original (pre-cleanup) environment
        via string.Template placeholders such as ``$HOME``.
        """
        import os
        import string

        previous = dict(os.environ)
        for name in list(os.environ.keys()):
            if name not in self.keep_vars:
                del os.environ[name]
        for name, value in self.set_vars.items():
            if type(value) == str:
                # allow $VAR references to the environment before cleanup
                os.environ[name] = string.Template(value).substitute(previous)
            else:
                os.environ[name] = str(value)

        for name, value in os.environ.items():
            logging.debug('environment var %s=%s' % (name, value))

    def __repr__(self):
        return repr(self.keep_vars) + ' ' + repr(self.set_vars)