import argparse
import collections
import dataclasses
import enum
import heapq
import importlib
import importlib.util
import inspect
import json
import linecache
import multiprocessing
import multiprocessing.queues
import os
import os.path
import queue
import shutil
import signal
import sys
import time
import traceback
from typing import *
from typing import TextIO

from crosshair.localhost_comms import StateUpdater, read_states
from crosshair.core_and_libs import AnalysisMessage, AnalysisOptions, MessageType, analyzable_members, analyze_module, analyze_any, exception_line_in_file
from crosshair.util import debug, extract_module_from_file, set_debug, CrosshairInternal, load_file, load_by_qualname, NotFound, ErrorDuringImport
import crosshair.core_and_libs

def command_line_parser() -> argparse.ArgumentParser:
    common = argparse.ArgumentParser(add_help=False)
    common.add_argument('--verbose', '-v', action='store_true')
    common.add_argument('--per_path_timeout', type=float)
    common.add_argument('--per_condition_timeout', type=float)
    parser = argparse.ArgumentParser(description='CrossHair Analysis Tool')
    subparsers = parser.add_subparsers(help='sub-command help', dest='action')
    check_parser = subparsers.add_parser(
        'check', help='Analyze one or more files', parents=[common])
    check_parser.add_argument('--report_all', action='store_true')
    check_parser.add_argument('files', metavar='F', type=str, nargs='+',
                              help='files or fully qualified modules, classes, or functions')
    watch_parser = subparsers.add_parser(
        'watch', help='Continuously watch and analyze files', parents=[common])
    watch_parser.add_argument('files', metavar='F', type=str, nargs='+',
                              help='files or directories to analyze')
    showresults_parser = subparsers.add_parser(
        'showresults', help='Display results from a currently running `watch` command', parents=[common])
    showresults_parser.add_argument('files', metavar='F', type=str, nargs='+',
                                    help='files or directories to analyze')
    return parser

def mtime(path: str) -> Optional[float]:
    try:
        return os.stat(path).st_mtime
    except FileNotFoundError:
        return None

def process_level_options(command_line_args: argparse.Namespace) -> AnalysisOptions:
    options = AnalysisOptions()
    for optname in ('per_path_timeout', 'per_condition_timeout', 'report_all'):
        arg_val = getattr(command_line_args, optname, False)
        if arg_val is not None:
            setattr(options, optname, arg_val)
    return options


@dataclasses.dataclass(init=False)
class WatchedMember:
    qual_name: str
    content_hash: int
    last_modified: float

    def get_member(self):
        return load_by_qualname(self.qual_name)

    def __init__(self, qual_name: str, body: str) -> None:
        self.qual_name = qual_name
        self.content_hash = hash(body)
        self.last_modified = time.time()

    def consider_new(self, new_version: 'WatchedMember') -> bool:
        if self.content_hash != new_version.content_hash:
            self.content_hash = new_version.content_hash
            self.last_modified = time.time()
            return True
        return False


WorkItemInput = Tuple[str, # (filename)
                      AnalysisOptions, float]  # (float is a deadline)
WorkItemOutput = Tuple[WatchedMember, Counter[str], List[AnalysisMessage]]

def import_error_msg(err: ErrorDuringImport) -> AnalysisMessage:
    orig, frame = err.args
    return AnalysisMessage(MessageType.IMPORT_ERR, str(orig),
                           frame.filename, frame.lineno, 0, '')

def pool_worker_main(item: WorkItemInput, output: multiprocessing.queues.Queue) -> None:
    try:
        # TODO figure out a more reliable way to suppress this. Redirect output?
        # Ignore ctrl-c in workers to reduce noisy tracebacks (the parent will kill us):
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        if hasattr(os, 'nice'): # analysis should run at a low priority
            os.nice(10)
        set_debug(False)
        filename, options, deadline = item
        stats: Counter[str] = Counter()
        options.stats = stats
        _, module_name = extract_module_from_file(filename)
        try:
            module = load_by_qualname(module_name)
        except NotFound:
            return
        except ErrorDuringImport as e:
            output.put((filename, stats, [import_error_msg(e)]))
            debug(f'Not analyzing "{filename}" because import failed: {e}')
            return
        messages = analyze_any(module, options)
        output.put((filename, stats, messages))
    except BaseException as e:
        raise CrosshairInternal(
            'Worker failed while analyzing ' + filename) from e


class Pool:
    _workers: List[Tuple[multiprocessing.Process, WorkItemInput]]
    _work: List[WorkItemInput]
    _results: multiprocessing.queues.Queue
    _max_processes: int

    def __init__(self, max_processes: int) -> None:
        self._workers = []
        self._work = []
        self._results = multiprocessing.Queue()
        self._max_processes = max_processes

    def _spawn_workers(self):
        work_list = self._work
        workers = self._workers
        while work_list and len(self._workers) < self._max_processes:
            work_item = work_list.pop()
            process = multiprocessing.Process(
                target=pool_worker_main, args=(work_item, self._results))
            workers.append((process, work_item))
            process.start()

    def _prune_workers(self, curtime):
        for worker, item in self._workers:
            (_, _, deadline) = item
            if worker.is_alive() and curtime > deadline:
                debug('Killing worker over deadline', worker)
                worker.terminate()
                time.sleep(0.5)
                if worker.is_alive():
                    worker.kill()
                    worker.join()
        self._workers = [(w, i) for w, i in self._workers if w.is_alive()]

    def terminate(self):
        self._prune_workers(float('+inf'))
        self._work = []
        self._results.close()

    def garden_workers(self):
        self._prune_workers(time.time())
        self._spawn_workers()

    def is_working(self):
        return self._workers or self._work

    def submit(self, item: WorkItemInput) -> None:
        self._work.append(item)

    def has_result(self):
        return not self._results.empty()

    def get_result(self, timeout: float) -> Optional[WorkItemOutput]:
        try:
            return self._results.get(timeout=timeout)
        except queue.Empty:
            return None


def worker_initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def analyzable_filename(filename: str) -> bool:
    if not filename.endswith('.py'):
        return False
    lead_char = filename[0]
    if (not lead_char.isalpha()) and (not lead_char.isidentifier()):
        # (skip temporary editor files, backups, etc)
        debug(
            f'Skipping {filename} because it begins with a special character.')
        return False
    if filename in ('setup.py',):
        debug(
            f'Skipping {filename} because files with this name are not usually import-able.')
        return False
    return True

def walk_paths(paths: Iterable[str]) -> Iterable[str]:
    for name in paths:
        if not os.path.exists(name):
            print(f'Watch path "{name}" does not exist.', file=sys.stderr)
            sys.exit(1)
        if os.path.isdir(name):
            for (dirpath, dirs, files) in os.walk(name):
                for curfile in files:
                    if analyzable_filename(curfile):
                        yield os.path.join(dirpath, curfile)
        else:
            yield name

class Watcher:
    _paths: Set[str]
    _pool: Pool
    _modtimes: Dict[str, float]
    _options: AnalysisOptions
    _next_file_check: float = 0.0
    _change_flag: bool = False

    def __init__(self, options: AnalysisOptions, files: Iterable[str], state_updater: StateUpdater):
        self._paths = set(files)
        self._state_updater = state_updater
        self._pool = self.startpool()
        self._modtimes = {}
        self._options = options
        _ = list(walk_paths(self._paths)) # just to force an exit if we can't find a path

    def startpool(self) -> Pool:
        return Pool(multiprocessing.cpu_count() - 1)

    def run_iteration(self,
                      max_condition_timeout=0.5) -> Iterator[
                          Tuple[Counter[str], List[AnalysisMessage]]]:
        debug(f'starting pass '
              f'with a condition timeout of {max_condition_timeout}')
        debug('Files:', self._modtimes.keys())
        pool = self._pool
        for filename in self._modtimes.keys():
            worker_timeout = max(10.0, max_condition_timeout * 20.0)
            options = dataclasses.replace(
                self._options, per_condition_timeout=max_condition_timeout)
            pool.submit((filename, options, time.time() + worker_timeout))

        pool.garden_workers()
        while pool.is_working():
            result = pool.get_result(timeout=1.0)
            if result is not None:
                (_, counters, messages) = result
                yield (counters, messages)
                if pool.has_result():
                    continue
            change_detected = self.check_changed()
            if change_detected:
                self._change_flag = True
                debug('Aborting iteration on change detection')
                pool.terminate()
                self._pool = self.startpool()
                return
            pool.garden_workers()
        debug('Worker pool tasks complete')
        yield (Counter(), [])

    def run_watch_loop(self) -> NoReturn:
        restart = True
        stats: Counter[str] = Counter()
        active_messages: Dict[Tuple[str, int], AnalysisMessage]
        while True:
            if restart:
                clear_screen()
                clear_line('-')
                line = f'  Analyzing {len(self._modtimes)} files.          \r'
                sys.stdout.write(color(line, AnsiColor.OKBLUE))
                max_condition_timeout = 0.5
                restart = False
                stats = Counter()
                active_messages = {}
            else:
                time.sleep(0.5)
                max_condition_timeout *= 2
            for curstats, messages in self.run_iteration(max_condition_timeout):
                debug('stats', curstats, messages)
                stats.update(curstats)
                if messages_merged(active_messages, messages):
                    self._state_updater.update(json.dumps({
                        'version': 1,
                        'time': time.time(),
                        'messages': [m.toJSON() for m in active_messages.values()]}))
                    linecache.checkcache()
                    clear_screen()
                    for message in active_messages.values():
                        lines = long_describe_message(message)
                        if lines is None:
                            continue
                        clear_line('-')
                        print(lines, end='')
                    clear_line('-')
                line = f'  Analyzed {stats["num_paths"]} paths in {len(self._modtimes)} files.          \r'
                sys.stdout.write(color(line, AnsiColor.OKBLUE))
            if self._change_flag:
                self._change_flag = False
                restart = True
                line = f'  Restarting analysis over {len(self._modtimes)} files.          \r'
                sys.stdout.write(color(line, AnsiColor.OKBLUE))

    def check_changed(self) -> bool:
        if time.time() < self._next_file_check:
            return False
        modtimes = self._modtimes
        changed = False
        for curfile in walk_paths(self._paths):
            cur_mtime = mtime(curfile)
            if cur_mtime == modtimes.get(curfile):
                continue
            changed = True
            if cur_mtime is None:
                del modtimes[curfile]
            else:
                modtimes[curfile] = cur_mtime
        self._next_file_check = time.time() + 1.0
        if not changed:
            return False
        return True


def clear_screen():
    print("\n" * shutil.get_terminal_size().lines, end='')


def clear_line(ch=' '):
    sys.stdout.write(ch * shutil.get_terminal_size().columns)


class AnsiColor(enum.Enum):
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def color(text: str, *effects: AnsiColor) -> str:
    return ''.join(e.value for e in effects) + text + AnsiColor.ENDC.value


def messages_merged(messages: MutableMapping[Tuple[str, int], AnalysisMessage],
                    new_messages: Iterable[AnalysisMessage]) -> bool:
    any_change = False
    for message in new_messages:
        key = (message.filename, message.line)
        if key not in messages or messages[key] != message:
            messages[key] = message
            any_change = True
    return any_change


def watch(args: argparse.Namespace, options: AnalysisOptions) -> int:
    # Avoid fork() because we've already imported the code we're watching:
    multiprocessing.set_start_method('spawn')

    if not args.files:
        print('No files or directories given to watch', file=sys.stderr)
        return 1
    try:
        with StateUpdater() as state_updater:
            watcher = Watcher(options, args.files, state_updater)
            watcher.check_changed()
            watcher.run_watch_loop()
    except KeyboardInterrupt:
        watcher._pool.terminate()
        print()
        print('I enjoyed working with you today!')
        return 0


def format_src_context(filename: str, lineno: int) -> str:
    amount = 3
    line_numbers = range(max(1, lineno - amount), lineno + amount + 1)
    output = [f'{filename}:{lineno}:\n']
    for curline in line_numbers:
        text = linecache.getline(filename, curline)
        if text == '': # (actual empty lines have a newline)
            continue
        output.append('>' + color(text, AnsiColor.WARNING)
                      if lineno == curline else '|' + text)
    return ''.join(output)


def long_describe_message(message: AnalysisMessage) -> Optional[str]:
    tb, desc, state = message.traceback, message.message, message.state
    desc = desc.replace(' when ', '\nwhen ')
    context = format_src_context(message.filename, message.line)
    intro = ''
    if state <= MessageType.CANNOT_CONFIRM:  # type: ignore
        return None
    elif message.state == MessageType.PRE_UNSAT:
        # TODO: This is disabled as unsat reasons are too common
        # intro = "I am having trouble finding any inputs that meet this precondition."
        return None
    elif message.state == MessageType.POST_ERR:
        intro = "I got an error while checking your postcondition."
    elif message.state == MessageType.EXEC_ERR:
        intro = "I found an exception while running your function."
    elif message.state == MessageType.POST_FAIL:
        intro = "I was able to make your postcondition return False."
    elif message.state == MessageType.SYNTAX_ERR:
        intro = "One of your conditions isn't a valid python expression."
    elif message.state == MessageType.IMPORT_ERR:
        intro = "I couldn't import a file."
    intro = color(intro, AnsiColor.FAIL)
    return f'{tb}\n{intro}\n{context}\n{desc}\n'


def short_describe_message(message: AnalysisMessage, options: AnalysisOptions) -> Optional[str]:
    desc = message.message
    if message.state <= MessageType.PRE_UNSAT:  # type: ignore
        if options.report_all:
            return '{}:{}:{}:{}'.format(message.filename, message.line, 'info', desc)
        return None
    if message.state == MessageType.POST_ERR:
        desc = 'Error while evaluating post condition: ' + desc
    return '{}:{}:{}:{}'.format(message.filename, message.line, 'error', desc)


def showresults(args: argparse.Namespace, options: AnalysisOptions) -> int:
    messages_by_file: Dict[str, List[AnalysisMessage]] = collections.defaultdict(list)
    states = list(read_states())
    for fname, content in states:
        debug('Found watch state file at ', fname)
        state = json.loads(content)
        for message in state['messages']:
            messages_by_file[message['filename']].append(AnalysisMessage.fromJSON(message))
    debug('Found results for these files: [', ', '.join(messages_by_file.keys()), ']')
    for name in walk_paths(args.files):
        name = os.path.abspath(name)
        debug('Checking file ', name)
        for message in messages_by_file[name]:
            desc = short_describe_message(message, options)
            debug('Describing ', message)
            if desc is not None:
                print(desc)
    return 0

def check(args: argparse.Namespace, options: AnalysisOptions, stdout: TextIO) -> int:
    any_problems = False
    for name in args.files:
        entity: object
        try:
            entity = load_file(name) if name.endswith('.py') else load_by_qualname(name)
        except ErrorDuringImport as e:
            stdout.write(str(short_describe_message(import_error_msg(e), options)) + '\n')
            any_problems = True
            continue
        debug('Check ', getattr(entity, '__name__', str(entity)))
        for message in analyze_any(entity, options):
            line = short_describe_message(message, options)
            if line is None:
                continue
            stdout.write(line + '\n')
            debug('Traceback for output message:\n', message.traceback)
            if message.state > MessageType.PRE_UNSAT:
                any_problems = True
    return 2 if any_problems else 0


def main() -> None:
    args = command_line_parser().parse_args()
    set_debug(args.verbose)
    options = process_level_options(args)
    if sys.path and sys.path[0] != '':
        # fall back to current directory to look up modules
        sys.path.append('')
    if args.action == 'check':
        exitcode = check(args, options, sys.stdout)
    elif args.action == 'showresults':
        exitcode = showresults(args, options)
    elif args.action == 'watch':
        exitcode = watch(args, options)
    else:
        print(f'Unknown action: "{args.action}"', file=sys.stderr)
        exitcode = 1
    sys.exit(exitcode)


if __name__ == '__main__':
    main()