#!/usr/bin/env python
# coding: utf-8

import os
import sys
import json
import math
import time
import shutil
import curses
import logging

from io import StringIO
from argparse import ArgumentParser
from itertools import chain
from itertools import islice
from multiprocessing import Queue
from multiprocessing import Process
from multiprocessing import current_process

from gevent import pool
from gevent import monkey
from gevent import timeout

# Replace the stdlib socket module with gevent's cooperative implementation
# (gevent.monkey.patch_socket) so that socket I/O done by poc functions lets
# the other greenlets of a worker's pool run instead of blocking them.
# Only the socket module is patched here; other blocking calls are presumably
# not made cooperative -- NOTE(review): confirm against the gevent docs.
monkey.patch_socket()


def count_file_linenum(filename):
    """ Count the number of lines in file

    The file is iterated line by line instead of being loaded whole with
    ``readlines()``, so memory use does not grow with the seed file size.

    Args:
        filename: A file path whose lines are counted.

    Returns:
        int, the number of lines in the file. A last line without a trailing
        newline still counts; an empty file gives 0.
    """
    with open(filename) as f:
        return sum(1 for _ in f)


def split_file_by_linenum(filename, linenum_of_perfile=30*10**6):
    """ Split one file into several files with a fixed number of lines each

    The pieces are written next to ``filename`` as ``<filename>_00``,
    ``<filename>_01``, ...; only the last piece may hold fewer lines. Lines
    are streamed from the source file rather than loaded all at once.

    Args:
        filename: A file path to split.
        linenum_of_perfile: Maximum number of lines per output file (>= 1).

    Returns:
        List of the generated file paths, in order; empty for an empty file.

    Raises:
        ValueError: If ``linenum_of_perfile`` is smaller than 1.
    """
    if linenum_of_perfile < 1:
        raise ValueError('linenum_of_perfile must be >= 1, got {!r}'
                         .format(linenum_of_perfile))

    def chunks(iterable, n):
        # Yield lazy chunks of ``n`` items. The ``for`` loop ends cleanly when
        # the input is exhausted; a bare ``next()`` here would let
        # StopIteration escape the generator, which Python 3.7+ (PEP 479)
        # turns into RuntimeError after the last chunk. Each chunk must be
        # fully consumed before the next one is requested.
        iterator = iter(iterable)
        for first in iterator:
            yield chain([first], islice(iterator, n - 1))

    filenames = []
    with open(filename) as f:
        for i, lines in enumerate(chunks(f, linenum_of_perfile)):
            part = '{}_{:02d}'.format(filename, i)
            filenames.append(part)
            with open(part, 'w') as wf:
                wf.writelines(lines)
    return filenames


def split_file_by_filenum(filename, filenum):
    """ Split one file into at most ``filenum`` files of similar line count

    Args:
        filename: A file path to split.
        filenum: Desired number of pieces; any value <= 1 means "no split".

    Returns:
        ``[filename]`` when no split is requested, otherwise the list of
        piece paths produced by ``split_file_by_linenum``.

    Raises:
        OptException: If the file has fewer lines than ``filenum``.
    """
    if filenum <= 1:
        return [filename]

    total_lines = count_file_linenum(filename)
    if total_lines < filenum:
        raise OptException('proc_num more than line number of seed file')

    # Round up so that no more than ``filenum`` pieces are produced.
    per_file = int(math.ceil(total_lines / float(filenum)))
    return split_file_by_linenum(filename, per_file)


class OptException(Exception):
    """Raised for missing, malformed or inconsistent launcher options."""


class Config(object):
    """Options of one scanning task.

    Every option has a class-level default of ``None`` that instances
    override through ``from_keys`` / ``from_jsonfile``.
    """

    # Options settable from the command line or a JSON config file.
    scanname = None      # scanning task name
    seedfile = None      # seed file path to process
    task_dir = None      # task files stored directory
    proc_num = None      # process number to use
    pool_size = None     # pool task size of per process
    pool_timeout = None  # pool task timeout of per process
    poc_file = None      # poc file path
    poc_func = None      # function to run in poc file
    poc_callback = None  # callback function in poc file

    enable_console = None  # console monitor on/off

    # Filled in at run time with objects looked up in the poc module.
    scan_func = None      # function method instance
    scan_callback = None  # callback method instance

    def from_keys(self, keys):
        """Copy options from the mapping ``keys`` onto this object.

        A key is applied only if this object already has an attribute of
        that name and its value is not ``None``; other entries are ignored.
        """
        for name, value in keys.items():
            if value is None or not hasattr(self, name):
                continue
            setattr(self, name, value)

    def from_jsonfile(self, jsonfile):
        """Load options from the JSON object stored in ``jsonfile``."""
        with open(jsonfile) as handle:
            options = json.load(handle)
        self.from_keys(options)

    @property
    def __dict__(self):
        """Option snapshot returned by ``vars(config)``.

        Only the nine plain options are included; ``enable_console``,
        ``scan_func`` and ``scan_callback`` are left out (the latter two
        hold callables, which ``json.dumps`` could not serialize).
        """
        names = ('scanname', 'seedfile', 'task_dir', 'proc_num',
                 'pool_size', 'pool_timeout', 'poc_file', 'poc_func',
                 'poc_callback')
        return {name: getattr(self, name) for name in names}


class ConsoleMonitor(object):
    """ Console monitor with builtin module "curses"

    Draws one progress bar per worker (top left), task counters (top right)
    and the latest worker output (bottom) until all worker processes have
    exited, then waits for "q". The terminal is always restored with
    ``curses.endwin()``, also when drawing fails or the user interrupts.

    documentation: https://docs.python.org/3/library/curses.html
    """
    def __init__(self, config, processes, progress_queue, output_queue):
        """
        Args:
            config: ``Config``; ``seedfile`` and ``proc_num`` are read.
            processes: Worker ``Process`` objects to watch and terminate.
            progress_queue: Queue of ``(proc_name, count, task_total)``.
            output_queue: Queue of ``(proc_name, output)``.
        """
        self.config = config
        self.processes = processes
        self.progress_queue = progress_queue
        self.output_queue = output_queue

        # main window and the three pads: progress, counters, output
        self.stdscr = None
        self.pgsscr = None
        self.cntscr = None
        self.optscr = None

        # (rows, columns) of the window / pads above
        self.stdscr_size = None
        self.pgsscr_size = None
        self.cntscr_size = None
        self.optscr_size = None

        self.task_total = None  # number of lines in the seed file
        self.task_num = None  # completed tasks summed over all workers
        self.start_time = time.time()
        self.progress = {}  # proc_name -> completed count of that worker
        self.contents = []  # lines currently shown in the output pad

        self.init_scr()

    def init_scr(self):
        """Start curses and allocate the three pads.

        If any step after ``curses.initscr()`` fails, the terminal is
        restored before the exception propagates.
        """
        self.stdscr = curses.initscr()
        try:
            curses.noecho()
            curses.curs_set(0)

            self.stdscr_size = self.stdscr.getmaxyx()
            self.task_total = count_file_linenum(self.config.seedfile)

            self.pgsscr_size = (self.config.proc_num + 2, 40)
            self.pgsscr = curses.newpad(*self.pgsscr_size)
            self.cntscr_size = (4, 40)
            self.cntscr = curses.newpad(*self.cntscr_size)
            self.optscr_size = (18, 80)
            self.optscr = curses.newpad(*self.optscr_size)
        except BaseException:
            curses.endwin()
            raise

    def build_progress_screen(self):
        """Drain ``progress_queue`` and redraw each reporting worker's bar."""
        # half the terminal width, but never narrower than 40 columns
        c_rows = max(self.config.proc_num + 2, 6)
        c_columns = (40 if self.stdscr_size[1] / 2 < 40
                     else self.stdscr_size[1] / 2)
        c_rows, c_columns = int(c_rows), int(c_columns)
        self.pgsscr_size = (c_rows, c_columns)
        self.pgsscr.resize(*self.pgsscr_size)
        # leave room for the worker index and the percentage around the bar
        bar_max = (25 if self.pgsscr_size[1] < 40
                   else self.pgsscr_size[1] - 15)

        while not self.progress_queue.empty():
            proc_name, count, task_total = self.progress_queue.get()
            self.progress[proc_name] = count
            # names look like "Worker-1 "; the number selects the pad row
            i = int(proc_name.split('-')[1])
            pct = float(count) / task_total
            bar = ('='*int(pct*bar_max)).ljust(bar_max)
            o = ' {:<2d} [{}{:>6.2f}%] '.format(i, bar, pct*100)
            self.pgsscr.addstr(i, 0, o)

        self.pgsscr.refresh(0, 0, 0, 0, c_rows, c_columns)

    def build_status_screen(self):
        """Redraw total / current / running-time counters right of the bars."""
        c_rows = max(self.config.proc_num + 2, 6)
        c_columns = (40 if self.stdscr_size[1] / 2 < 40
                     else self.stdscr_size[1] / 2)
        c_rows, c_columns = int(c_rows), int(c_columns)
        self.cntscr_size = (c_rows, c_columns)
        self.task_num = sum([v for k, v in self.progress.items()])
        running_time = time.strftime('%H:%M:%S',
                                     time.gmtime(time.time()-self.start_time))
        self.cntscr.resize(*self.cntscr_size)
        self.cntscr.addstr(1, 0, 'Total: {}'.format(self.task_total))
        self.cntscr.addstr(2, 0, 'Current: {}'.format(self.task_num))
        self.cntscr.addstr(4, 0, 'Running Time: {}'.format(running_time))
        self.cntscr.refresh(0, 0, 0, c_columns, c_rows, c_columns*2)

    def build_output_screen(self):
        """Drain ``output_queue`` and scroll new lines into the bottom pad.

        Every line is also passed to the ``output.without.stream`` logger.
        """
        without_stream_logger = logging.getLogger('output.without.stream')
        offset_rows = int(max(self.pgsscr_size[0], self.cntscr_size[0]))
        c_rows = self.stdscr_size[0] - offset_rows
        c_columns = self.stdscr_size[1]
        c_rows, c_columns = int(c_rows), int(c_columns)

        self.optscr_size = (c_rows, c_columns)
        self.optscr.resize(*self.optscr_size)
        self.optscr.border(1, 1, 0, 0)

        # Keep exactly c_rows - 1 lines. ``>=`` (not ``>``) also trims the
        # len == c_rows case left behind when the terminal shrinks, which
        # would otherwise put a line on the pad's last row.
        if len(self.contents) >= c_rows:
            self.contents = self.contents[len(self.contents)-c_rows+1:]
        else:
            self.contents.extend(['']*(c_rows-len(self.contents)-1))

        while not self.output_queue.empty():
            proc_name, output = self.output_queue.get()
            o = '{}'.format(output)
            without_stream_logger.info(o)

            # scroll by one line, clipping to the terminal width
            self.contents = self.contents[1:]
            self.contents.append(o[:c_columns])
            self.optscr.move(0, 0)
            self.optscr.clrtobot()
            for i, v in enumerate(self.contents):
                self.optscr.addstr(i, 0, v)

            self.optscr.refresh(0, 0, offset_rows, 0,
                                c_rows + offset_rows, c_columns)

    def run(self):
        """Refresh all screens until the workers exit, then wait for "q".

        Workers are terminated once the completed count reaches the number
        of seed-file lines. ``curses.endwin()`` runs in ``finally`` so the
        terminal is restored even on errors or KeyboardInterrupt.
        """
        try:
            while any(_.is_alive() for _ in self.processes):
                time.sleep(0.1)
                self.stdscr_size = self.stdscr.getmaxyx()
                self.build_progress_screen()
                self.build_status_screen()
                self.build_output_screen()

                # terminate manually when all tasks finished
                if self.task_num == self.task_total:
                    for _ in self.processes:
                        _.terminate()

            self.stdscr.addstr(self.stdscr_size[0] - 2, 0,
                               'Done! please type "q" to exit.')
            self.stdscr.refresh()
            while self.stdscr.getch() != ord('q'):
                time.sleep(1)
        finally:
            curses.endwin()


class ProcessIO(StringIO):
    """ File-like stdout replacement that forwards writes to a queue

    Each worker process installs one as ``sys.stdout``, so ``print`` output
    of the poc code reaches the parent as ``(proc_name, text)`` tuples.
    """
    def __init__(self, output_queue, *args, **kwargs):
        # Initialize StringIO itself. The former ``super(StringIO, self)``
        # skipped StringIO.__init__, leaving the object uninitialized, so
        # flush(), getvalue() and ``closed`` raised ValueError.
        super(ProcessIO, self).__init__(*args, **kwargs)
        self.output_queue = output_queue
        self.proc_name = current_process().name

    def write(self, s):
        """Queue ``s`` with surrounding whitespace stripped.

        A lone ``'\\n'`` (what ``print`` writes as its line end) is dropped.
        Returns ``len(s)``, as the text I/O ``write`` protocol expects.
        """
        if s != '\n':
            self.output_queue.put((self.proc_name, s.strip()))
        return len(s)


class ProcessTask(object):
    """ Scan every seed of one seed file inside a gevent pool

    ``run`` is used as the target of each worker ``Process``.
    """
    def __init__(self, scan_func, pool_size, pool_timeout):
        """
        Args:
            scan_func: Callable invoked as ``scan_func(seed)`` per seed line.
            pool_size: Maximum number of concurrently running greenlets.
            pool_timeout: Seconds allowed for each ``scan_func`` call.
        """
        self.scan_func = scan_func
        self.pool_size = pool_size
        self.pool_timeout = pool_timeout

    @staticmethod
    def callback(result):
        """Default result callback: return ``result`` unchanged.

        May be replaced per instance by assigning ``instance.callback``.
        """
        return result

    def pool_task_with_timeout(self, line):
        """Run ``scan_func`` on one stripped seed line under a timeout.

        Returns:
            ``dict(seed=..., data=..., exception=...)``: ``data`` is the
            return value on success; ``exception`` is ``str(error)`` when
            ``scan_func`` raises or exceeds ``pool_timeout`` seconds.
        """
        seed = line.strip()
        result = dict(seed=seed, data=None, exception=None)
        try:
            data = timeout.with_timeout(self.pool_timeout,
                                        self.scan_func,
                                        seed)
        # gevent's Timeout derives from BaseException, so name it explicitly
        except (Exception, timeout.Timeout) as ex:
            result['exception'] = str(ex)
        else:
            result['data'] = data
        return result

    def _scan_and_callback(self, line):
        """Scan one seed line, pass the result dict to ``self.callback``.

        Running the callback inside the pooled greenlet means ``Pool.join``
        waits for it and progress is only reported after it has finished.
        """
        result = self.pool_task_with_timeout(line)
        self.callback(result)
        return result

    def run(self, seedfile, progress_queue, output_queue):
        """Scan every line of ``seedfile``; meant as a ``Process`` target.

        stdout is redirected to ``output_queue`` and, each time a seed
        finishes, ``(proc_name, completed_count, task_total)`` is put on
        ``progress_queue``.
        """
        task_total = count_file_linenum(seedfile)
        proc_name = current_process().name
        sys.stdout = ProcessIO(output_queue)
        completed = [0]  # mutable cell shared with the link callback

        def progress_tracking(greenlet):
            completed[0] += 1
            progress_queue.put((proc_name, completed[0], task_total))
            return greenlet

        po = pool.Pool(self.pool_size)
        with open(seedfile) as f:
            for line in f:
                # Pool.spawn registers the greenlet with the pool exactly once
                # and blocks while the pool is full, which throttles reading.
                # (apply_async + add registered it twice, so each task held
                # two slots and only pool_size / 2 scans ran concurrently.)
                g = po.spawn(self._scan_and_callback, line)
                g.link(progress_tracking)

        try:
            po.join()
        except (KeyboardInterrupt, SystemExit) as ex:
            print(str(ex))
            po.kill()


class Launcher(object):
    """ Set up a scanning task and run its worker processes

    The constructor resolves the options, prepares the task directory,
    imports the poc module and configures the output loggers; ``run``
    starts and watches the workers.
    """
    def __init__(self, options):
        """
        Args:
            options: ``argparse.Namespace`` from ``commands()``; its
                ``CONFIG`` attribute is removed while it is read.

        Raises:
            OptException: On missing/invalid options or missing input files.
        """
        self.config = Config()
        self._init_conf(options)
        self._init_env()
        self._init_mod()
        self._init_logger()

    def _init_conf(self, options):
        """Fill ``self.config`` from ``options`` and the optional JSON file.

        JSON values are applied after the command-line values, so they take
        precedence (argparse defaults included). Every option name known to
        ``Config`` must end up non-None.
        """
        config = options.CONFIG
        opts = vars(options)
        opts.pop('CONFIG')  # vars() is the Namespace's own dict
        opts = dict((k.lower(), v) for k, v in opts.items())
        self.config.from_keys(opts)
        if config:
            self.config.from_jsonfile(config)

        # check options required
        for k in opts:
            if hasattr(self.config, k) and getattr(self.config, k) is None:
                raise OptException('{} option required, '
                                   'use -h for help'.format(k))

    def _init_env(self):
        """Validate options and prepare the task directory.

        Converts the numeric options to int, copies the seed and poc files
        into the task directory, points the config at those copies, dumps
        the options to ``config.json`` there and changes into it.
        """
        cwd = os.getcwd()
        task_dir = os.path.realpath(os.path.join(cwd, self.config.task_dir))
        seedfile = os.path.realpath(os.path.join(cwd, self.config.seedfile))
        poc_file = os.path.realpath(os.path.join(cwd, self.config.poc_file))

        try:
            self.config.proc_num = int(self.config.proc_num)
            self.config.pool_size = int(self.config.pool_size)
            self.config.pool_timeout = int(self.config.pool_timeout)
        except ValueError as ex:
            raise OptException('wrong option type, "{}"'.format(str(ex)))

        if not os.path.exists(seedfile):
            raise OptException('seed file not exists, {}'.format(seedfile))
        if not os.path.exists(poc_file):
            raise OptException('poc file not exists, {}'.format(poc_file))

        if not os.path.exists(task_dir):
            os.makedirs(task_dir)

        task_runtime_dir = task_dir

        shutil.copy(seedfile, task_runtime_dir)
        self.config.seedfile = os.path.realpath(os.path.join(task_runtime_dir,
                                                os.path.basename(seedfile)))
        shutil.copy(poc_file, task_runtime_dir)
        self.config.poc_file = os.path.realpath(os.path.join(task_runtime_dir,
                                                os.path.basename(poc_file)))
        self.config.task_dir = task_dir

        # dump options to json file in task directory
        d_opts = vars(self.config)
        conffile = os.path.join(task_runtime_dir, 'config.json')
        with open(conffile, 'w') as f:
            f.write(json.dumps(d_opts, indent=4, sort_keys=True))

        os.chdir(task_runtime_dir)

    def _init_mod(self):
        """Import the poc module and resolve its scan / callback functions.

        The poc directory is appended (not prepended) to ``sys.path``.
        NOTE(review): an earlier path entry or an already-imported module
        with the same name would therefore win over the poc file.
        """
        sys.path.append(
            os.path.abspath(os.path.dirname(self.config.poc_file)))
        poc_name = os.path.splitext(os.path.basename(self.config.poc_file))[0]
        poc_mod = __import__(poc_name)
        self.config.scan_func = getattr(poc_mod, self.config.poc_func)
        if self.config.poc_callback:
            self.config.scan_callback = getattr(poc_mod,
                                                self.config.poc_callback)

    def _init_logger(self):
        """Create the two output loggers.

        ``output.with.stream`` writes to ``output.log`` and the console;
        ``output.without.stream`` writes to ``output.log`` only.
        """
        output_file_handler = logging.FileHandler('output.log', mode='w')
        output_file_handler.setFormatter(logging.Formatter('%(message)s'))
        output_file_handler.setLevel(logging.INFO)

        output_stream_handler = logging.StreamHandler()
        output_stream_handler.setFormatter(logging.Formatter('%(message)s'))
        output_stream_handler.setLevel(logging.INFO)

        with_stream_logger = logging.getLogger('output.with.stream')
        without_stream_logger = logging.getLogger('output.without.stream')

        with_stream_logger.addHandler(output_file_handler)
        with_stream_logger.addHandler(output_stream_handler)
        with_stream_logger.setLevel(logging.DEBUG)

        without_stream_logger.addHandler(output_file_handler)
        without_stream_logger.setLevel(logging.DEBUG)

    def run(self):
        """ Start ProcessTask main function

        Splits the seed file into one chunk per process, starts a worker
        ``Process`` for each chunk and watches them, either with the curses
        ``ConsoleMonitor`` or by streaming output to the logger.
        """
        filenames = split_file_by_filenum(self.config.seedfile,
                                          self.config.proc_num)
        output_queue = Queue()
        progress_queue = Queue()
        w = ProcessTask(self.config.scan_func,
                        self.config.pool_size,
                        self.config.pool_timeout)
        if self.config.scan_callback:
            w.callback = self.config.scan_callback

        processes = [Process(name='Worker-{:<2d}'.format(i + 1),
                             target=w.run,
                             args=(filename, progress_queue, output_queue))
                     for i, filename in enumerate(filenames)]

        for p in processes:
            p.start()

        if self.config.enable_console:
            monitor = ConsoleMonitor(self.config,
                                     processes,
                                     progress_queue,
                                     output_queue)
            monitor.run()
        else:
            self._watch_processes(processes, progress_queue, output_queue)

    def _watch_processes(self, processes, progress_queue, output_queue):
        """Log worker output until every worker exits (non-console mode).

        All workers are terminated once the completed count reaches the
        number of lines in the whole seed file. Output queued after the last
        poll is drained before returning, and the workers are joined.
        """
        task_total = count_file_linenum(self.config.seedfile)
        with_stream_logger = logging.getLogger('output.with.stream')
        progress = {}

        while any(p.is_alive() for p in processes):
            time.sleep(0.1)
            task_num = self._drain_queues(progress_queue, output_queue,
                                          progress, with_stream_logger)
            # terminate manually when all tasks finished
            if task_num == task_total:
                for p in processes:
                    p.terminate()

        # pick up whatever the workers queued after the last poll
        self._drain_queues(progress_queue, output_queue, progress,
                           with_stream_logger)
        for p in processes:
            p.join()

    @staticmethod
    def _drain_queues(progress_queue, output_queue, progress, logger):
        """Consume all queued progress/output; return the completed total.

        ``progress`` maps proc_name -> latest completed count and is updated
        in place. The per-worker chunk total in each progress tuple is
        ignored; it used to overwrite the seed-file total and stop all
        workers early.
        """
        while not progress_queue.empty():
            proc_name, count, _chunk_total = progress_queue.get()
            progress[proc_name] = count
        while not output_queue.empty():
            _proc_name, output = output_queue.get()
            logger.info('{}'.format(output))
        return sum(progress.values())


DESC = 'A lightweight batch scanning framework based on gevent.'


def commands(argv=None):
    """ Build the command-line parser and parse the arguments

    Args:
        argv: Argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        ``argparse.Namespace`` whose attributes are the upper-case ``dest``
        names below (``CONFIG``, ``SEEDFILE``, ...); ``Launcher`` lower-cases
        them to match ``Config`` attributes.
    """
    parser = ArgumentParser(description=DESC)

    parser.add_argument('-c', '--config', dest='CONFIG', default=None,
                        type=str, help='config file of launcher')
    parser.add_argument('-n', '--scanname', dest='SCANNAME',
                        type=str, help='alias name of launcher')
    parser.add_argument('-t', '--seedfile', dest='SEEDFILE',
                        type=str, help='seed file path to scan')
    parser.add_argument('-r', '--poc-file', dest='POC_FILE',
                        type=str, help='poc file path to load')
    parser.add_argument('-f', '--poc-func', dest='POC_FUNC',
                        default='run', type=str,
                        help='function name to run in poc file')
    parser.add_argument('-b', '--poc-callback', dest='POC_CALLBACK',
                        default='callback', type=str,
                        help='callback function name in poc file')

    parser.add_argument('--task-dir', dest='TASK_DIR',
                        default='tasks/', type=str,
                        help='task files stored directory (default: "tasks/")')
    parser.add_argument('--proc-num', dest='PROC_NUM',
                        default=4, type=int,
                        help='process numbers to run (default: 4)')
    parser.add_argument('--pool-size', dest='POOL_SIZE',
                        default=100, type=int,
                        help='pool size in per process (default: 100)')
    parser.add_argument('--pool-timeout', dest='POOL_TIMEOUT',
                        default=180, type=int,
                        help='pool timeout in per process (default: 180)')

    parser.add_argument('--enable-console', dest='ENABLE_CONSOLE',
                        action='store_true', default=False,
                        help='enable real-time console monitor')

    return parser.parse_args(argv)


if __name__ == '__main__':
    try:
        launcher = Launcher(commands())
        launcher.run()
    except OptException as ex:
        # Option/config problems are user errors (the messages already say
        # e.g. "use -h for help"): print them on stderr and exit with
        # status 1 instead of dumping a traceback.
        sys.exit('error: {}'.format(ex))