import functools as ft
import inspect
import json
import re
import time
from contextlib import contextmanager
from imaplib import CRLF, Time2Internaldate

from gevent import Timeout
from gevent.lock import RLock
from gevent.pool import Pool

from . import conf, fn_desc, fn_time, log

# registered command functions mapped to their option dicts
# ({'writable': ..., 'dovecot': ...}); filled by the ``command`` decorator
commands = {}
# connections reused by ``using``, keyed by (user, client, box)
pool = {}


class Error(Exception):
    """Exception type raised by the helpers of this module."""

    def __repr__(self):
        # rendered as "<module>.<class>: <args tuple>"
        parts = (__name__, type(self).__name__, self.args)
        return '%s.%s: %s' % parts


def using(client, box, readonly=True, name='con', reuse=True, parent=False):
    """Build a decorator that supplies a connection to the wrapped function.

    The connection is injected as keyword argument ``name`` unless the
    caller already passed a truthy value under that keyword.

    client: callable creating a connection; called as ``client(None)`` for
        pooled connections or ``client(box, readonly=readonly)`` otherwise.
    box: mailbox to select on the connection (skipped when falsy).
    readonly: passed through to ``select``/``client``.
    name: keyword under which the connection is injected; when falsy
        nothing is injected.
    reuse: take the connection from the module level ``pool`` (keyed by
        user, client and box) instead of opening a fresh one per call.
    parent: when truthy it is stored in ``con.parent`` for the duration of
        the call; a connection with a truthy ``parent`` is not re-selected.

    Plain functions and generator functions are both supported.
    """
    @contextmanager
    def use_or_create(kw):
        if kw.get(name):
            # the caller brought its own connection
            yield
            return

        if not reuse:
            with client(box, readonly=readonly) as con:
                if name:
                    kw[name] = con
                yield
            return

        key = conf['USER'], client, box
        if key not in pool:
            pool[key] = client(None)
        con = pool[key]
        parent_orig = con.parent
        if not con.parent and box:
            try:
                con.select(box, readonly)
            except con.abort as e:
                # probably connection is already expired
                # try to create new one
                log.error(e)
                con = pool[key] = client(None)
                con.parent = parent_orig
                con.select(box, readonly)
            if parent:
                con.parent = parent
        if name:
            kw[name] = con
        try:
            yield
        finally:
            # restore the marker even when the wrapped call raised,
            # otherwise the pooled connection would never be re-selected
            if con.parent:
                con.parent = parent_orig

    def wrapper(fn):
        # ``fn`` is captured per decoration; storing it only on ``wrapper``
        # made every function decorated with the same ``using(...)`` object
        # call the most recently decorated one
        if inspect.isgeneratorfunction(fn):
            def inner(*a, **kw):
                with use_or_create(kw):
                    yield from fn(*a, **kw)
        else:
            def inner(*a, **kw):
                with use_or_create(kw):
                    return fn(*a, **kw)

        # kept for backward compatibility with code reading ``wrapper.fn``
        wrapper.fn = fn
        return ft.wraps(fn)(inner)
    return wrapper


def clean_pool(user=None):
    """Log out and drop every pooled connection that belongs to ``user``.

    When ``user`` is ``None`` the current ``conf['USER']`` is used.
    """
    target = conf['USER'] if user is None else user

    # snapshot the keys first: entries are removed while walking them
    owned = [key for key in list(pool) if key[0] == target]
    for key in owned:
        con = pool.pop(key, None)
        if con:
            con.logout()


def cmd_locked(func):
    """Decorate ``func`` so each call runs while holding ``con.lock``.

    The call itself goes through ``fn_time``.
    """
    def inner(con, *a, **kw):
        with con.lock:
            timed = fn_time(func)
            return timed(con, *a, **kw)
    return ft.wraps(func)(inner)


def cmd_writable(func):
    """Decorate ``func`` so it refuses to run on a read-only connection.

    Raises ``ValueError`` when ``con.is_readonly`` is truthy.
    """
    def inner(con, *a, **kw):
        if not con.is_readonly:
            return func(con, *a, **kw)
        raise ValueError('%s must be writable' % con)
    return ft.wraps(func)(inner)


def cmd_error(func):
    """Re-raise ``con.error`` raised by ``func`` as this module's ``Error``."""
    def inner(con, *a, **kw):
        try:
            result = func(con, *a, **kw)
        except con.error as e:
            raise Error(e)
        return result
    return ft.wraps(func)(inner)


def check(res):
    """Return the data of a ``(typ, data)`` response; raise unless ``OK``."""
    status, payload = res
    if status == 'OK':
        return payload
    raise Error(status, payload)


def command(*, name=None, lock=True, writable=False, dovecot=False):
    """Build a decorator registering a function in ``commands``.

    name: value stored on the function as ``.name``; defaults to the
        function's ``__name__``.
    lock: wrap the function with ``cmd_locked``.
    writable, dovecot: flags recorded in the registry entry.

    The function is always wrapped with ``cmd_error``; the wrapped function
    is both the registry key and the decorator's result.
    """
    def inner(func):
        func.name = name if name else func.__name__
        wrapped = cmd_locked(func) if lock else func
        wrapped = cmd_error(wrapped)
        commands[wrapped] = {'writable': writable, 'dovecot': dovecot}
        return wrapped
    return inner


class Conn:
    """Connection helpers: default state and a text representation."""

    def defaults(self):
        """Reset the selected mailbox and its flags to ``None``."""
        self.flags = None
        self.current_box = None

    def __str__(self):
        cls_name = type(self).__name__
        return '%s{%r, %r}' % (cls_name, self.username, self.current_box)

    def __repr__(self):
        return str(self)


class Ctx:
    """Facade over a connection exposing selected attributes read-only.

    Leaving the ``with`` block calls ``self.logout()``; that attribute is
    not defined by the class itself and has to be attached to the instance.
    """

    def __init__(self, con):
        self.parent = False
        self._con = con

    def __str__(self):
        return 'Ctx:%s' % (self._con)

    def __repr__(self):
        return str(self)

    # read-only views onto the wrapped connection
    username = property(lambda self: self._con.username)
    box = property(lambda self: self._con.current_box)
    is_readonly = property(lambda self: self._con.is_readonly)
    abort = property(lambda self: self._con.abort)
    error = property(lambda self: self._con.error)
    flags = property(lambda self: self._con.flags)
    uidnext = property(lambda self: self._con.uidnext)
    uidvalidity = property(lambda self: self._con.uidvalidity)
    highestmodseq = property(lambda self: self._con.highestmodseq)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.logout()


def client(connect, *, writable=False, dovecot=False, debug=None):
    """Open a connection with ``connect`` and return a ``Ctx`` around it.

    connect: zero-argument callable returning a connection object.
    writable: also expose commands registered as writable.
    dovecot: also expose commands registered as dovecot-only.
    debug: value for ``con.debug``; ``conf['DEBUG_IMAP']`` when ``None``.

    Every exposed command is attached to the ``Ctx`` under its ``name``
    with the connection already bound as first argument.
    """
    timed_connect = fn_time(
        connect, '{0.__module__}.{0.__name__}'.format(connect)
    )

    def start():
        fresh = timed_connect()
        fresh.debug = conf['DEBUG_IMAP'] if debug is None else debug
        fresh.lock = RLock()
        fresh.new = new
        return fresh

    def new():
        # sibling connection with the same mailbox selected as the primary
        clone = start()
        if con.current_box:
            clone.select(con.current_box, con.is_readonly)
        return clone

    con = start()
    ctx = Ctx(con)

    for cmd, opts in commands.items():
        if opts['dovecot'] and not dovecot:
            continue
        if opts['writable'] and not writable:
            continue
        setattr(ctx, cmd.name, ft.partial(cmd, con))
    return ctx


def login(con, username, password):
    """Log in on ``con`` and return the response data.

    ``con.error`` is re-raised as ``Error``; a non-OK response raises
    ``Error`` as well.
    """
    try:
        response = con.login(username, password)
        return check(response)
    except con.error as e:
        raise Error(e)


@contextmanager
def _cmd(con, name):
    tag = con._new_tag()

    def start(args):
        if isinstance(args, str):
            args = args.encode()
        return con.send(b'%s %s%s' % (tag, name.encode(), args))
    yield tag, start, lambda: con._command_complete(name, tag)


def _mdkey(key):
    if not key.startswith('/private'):
        key = '/private/%s' % key
    return key


@command(dovecot=True, writable=True)
def setmetadata(con, box, key, value):
    """Send SETMETADATA storing JSON-encoded ``value`` under ``key``.

    ``key`` is placed under ``/private`` when it is not already; returns the
    data of the untagged ``METADATA`` response.
    """
    entry = _mdkey(key)
    with _cmd(con, 'SETMETADATA') as (tag, start, complete):
        encoded = json.dumps(value)
        line = ' %s (%s %s)' % (box, entry, encoded)
        start(line.encode() + CRLF)
        typ, data = complete()
        response = con._untagged_response(typ, data, 'METADATA')
        return check(response)


@command(dovecot=True)
def getmetadata(con, box, key):
    """Send GETMETADATA for ``key`` of ``box``.

    ``key`` is placed under ``/private`` when it is not already; returns the
    data of the untagged ``METADATA`` response.
    """
    entry = _mdkey(key)
    with _cmd(con, 'GETMETADATA') as (tag, start, complete):
        line = ' %s (%s)' % (box, entry)
        start(line.encode() + CRLF)
        typ, data = complete()
        response = con._untagged_response(typ, data, 'METADATA')
        return check(response)


@command(dovecot=True)
def sieve(con, criteria, script):
    """Send ``FILTER SIEVE SCRIPT`` with ``script`` for ``criteria``.

    Returns the untagged ``FILTERED`` responses (an empty list when there
    are none).  Raises ``Error`` when the server sent an untagged ``FILTER``
    response or the command did not complete with ``OK``.
    """
    script_bytes = script.strip().encode()
    criteria_bytes = criteria.encode()
    with _cmd(con, 'FILTER') as (tag, start, complete):
        header = ' SIEVE SCRIPT {%s}' % len(script_bytes)
        start(header.encode() + CRLF)
        con.send(script_bytes)
        con.send(criteria_bytes + CRLF)
        typ, data = complete()
        log.debug('%s; %s', typ, data[0].decode())

        failure = con.untagged_responses.pop('FILTER', None)
        if failure:
            raise Error(failure[0][1])
        if typ != 'OK':
            raise Error(typ, data)
        return con.untagged_responses.pop('FILTERED', [])


def clean_recent(flags):
    r"""Return ``flags`` as ``str`` with every ``\Recent`` token removed.

    flags: space separated flags as ``str`` or ``bytes``; falsy values
        (``None``, ``''``, ``b''``) are returned unchanged.

    The result carries no leading or trailing whitespace, so a value made
    only of ``\Recent`` becomes ``''`` and callers can test it for
    emptiness.
    """
    if not flags:
        return flags

    if isinstance(flags, bytes):
        flags = flags.decode()
    # the lookahead keeps the following separator available, so consecutive
    # ``\Recent`` tokens are all matched
    cleaned = re.sub(r'(?:^| )\\Recent(?= |$)', '', flags)
    return cleaned.strip()


def _multiappend(con, box, msgs):
    """Append all ``msgs`` to ``box`` with a single APPEND command.

    msgs: iterable of ``(date_time, flags, msg)``; ``date_time`` of ``None``
        means "now", ``flags`` may be empty or ``None``.

    Returns the uid part of the ``APPENDUID`` response (the text after the
    first space).  Raises ``Error`` when the command is answered before all
    messages were sent or does not complete with ``OK``.
    """
    with _cmd(con, 'APPEND') as (tag, start, complete):
        send = start
        for date_time, flags, msg in msgs:
            # a missing flag list has to go out as "()";
            # formatting None/b'' directly would yield "(None)" or "(b'')"
            flags = clean_recent(flags) or ''
            if date_time is None:
                date_time = Time2Internaldate(time.time())
            args = (' (%s) %s %s' % (flags, date_time, '{%s}' % len(msg)))
            if send is start:
                # only the first message carries the tag, verb and mailbox
                args = ' %s %s' % (box, args)
            send(args.encode() + CRLF)
            send = con.send
            # read responses until none is returned; a result for our tag
            # at this point means the server rejected the command
            while con._get_response():
                bad = con.tagged_commands[tag]
                if bad:
                    raise Error(bad)
            con.send(msg)
        con.send(CRLF)
        res = check(complete())
        log.debug('%s', res[0].decode())
        uids = con.untagged_responses.pop('APPENDUID')
        uids = uids[0].decode().split(' ', 1)[-1]
        return uids


@command(dovecot=True, writable=True, lock=False)
def multiappend(con, box, msgs, *, batch=None, threads=10):
    """Append ``msgs`` to ``box``, optionally in concurrent batches.

    batch: when set and ``msgs`` is longer, the messages are split into
        slices of that size, each appended over its own ``con.new()``
        connection.
    threads: size of the gevent pool running the batches.

    Returns ``None`` for empty ``msgs``, otherwise the appended uids (the
    per-batch results joined with ``,``).
    """
    if not msgs:
        return

    if not batch or len(msgs) <= batch:
        with con.lock:
            return _multiappend(con, box, msgs)

    def append_chunk(num, few):
        with con.new() as c:
            res = multiappend(c, box, few)
            log.debug('#%s multiappend %s messages', num, len(few))
            return res

    workers = Pool(threads)
    jobs = []
    for num, offset in enumerate(range(0, len(msgs), batch)):
        chunk = msgs[offset:offset+batch]
        jobs.append(workers.spawn(append_chunk, num, chunk))
    workers.join(raise_error=True)
    return ','.join(j.value for j in jobs)


@command(dovecot=True)
def thread(con, *criteria):
    """Run ``UID THREAD`` with ``criteria`` and return the parsed threads."""
    res = check(con.uid('THREAD', *criteria))
    line = res[0].decode() if res else ''
    return parse_thread(line)


@command(dovecot=True)
def sort(con, fields, *criteria, charset='UTF-8'):
    """Run ``UID SORT`` and return the uids as a list of strings."""
    res = check(con.uid('SORT', fields, charset, *criteria))
    line = res[0].decode()
    return line.split()


@command()
def idle(con, handlers, timeout=None):
    """Start IDLE on ``con`` and feed untagged responses to ``handlers``.

    handlers: mapping of untagged response code to a callable that is
        invoked with the data received for that code.
    timeout: passed to ``gevent.Timeout`` around every wait for a server
        response (logged as seconds).

    Returns ``None`` once a wait runs into the timeout.  Raises ``Error``
    when the IDLE tag already has a result in ``con.tagged_commands``.

    NOTE(review): ``complete`` is never called and nothing is sent after a
    timeout - presumably the caller drops or re-idles the connection;
    confirm.
    """
    def match():
        # hand the pending data of every watched code to its handler
        for code, handler in handlers.items():
            typ, dat = con._untagged_response('OK', [None], code)
            if not dat[-1]:
                # nothing pending for this code
                continue
            handler(dat)

    def inner(tag):
        # only the wait itself is guarded by the timeout
        with Timeout(timeout):
            res = con._get_response()
        if res:
            log.debug('received: %r', res.decode())
        # a result for our own tag means the command has ended
        bad = con.tagged_commands[tag]
        if bad:
            raise Error(bad)
        match()

    # deliver whatever is already pending before going idle
    match()
    log.info('start idling %s...' % con)
    with _cmd(con, 'IDLE') as (tag, start, complete):
        # log out the current user's pooled connections before idling
        clean_pool()
        start(CRLF)
        while 1:
            try:
                inner(tag)
            except Timeout:
                log.debug('timeout reached: %ss', timeout)
                return


@command(lock=False)
def enable(con, capability):
    """Enable ``capability`` on ``con`` and return the response data."""
    res = con.enable(capability)
    return check(res)


@command()
def logout(con, timeout=1):
    """Log out within ``timeout``; ``con.abort`` is logged and swallowed.

    Returns the result of ``con.logout()`` or ``None`` after an abort.
    """
    with Timeout(timeout):
        try:
            res = con.logout()
        except con.abort as e:
            log.error(e)
        else:
            return res


@command(name='list')
def xlist(con, folder='""', pattern='*'):
    """List mailboxes under ``folder`` matching ``pattern``."""
    res = con.list(folder, pattern)
    return check(res)


@command()
def select(con, box, readonly=True):
    """Select ``box`` and cache the mailbox state on ``con``.

    Sets ``current_box``, ``flags``, ``uidnext``, ``uidvalidity`` and
    ``highestmodseq`` from the untagged responses; returns the SELECT data.
    """
    def first(code):
        # first untagged value received for ``code``, decoded
        return con.untagged_responses[code][0].decode()

    res = check(con.select(box, readonly))
    con.current_box = box.decode() if isinstance(box, bytes) else box
    con.flags = first('FLAGS')[1:-1].split()
    con.uidnext = int(first('UIDNEXT'))
    con.uidvalidity = first('UIDVALIDITY')
    con.highestmodseq = int(first('HIGHESTMODSEQ'))
    return res


@ft.lru_cache(None)
def find_folder(con, tag):
    """Find the first listed folder whose attribute list contains ``tag``.

    Returns ``(folder, folders)``: the folder name as ``bytes`` (``None``
    when nothing matched) and the complete listing.  Results are cached per
    ``(con, tag)``.
    """
    needle = tag.encode() if isinstance(tag, str) else tag
    folders = xlist(con)
    for line in folders:
        if re.search(br'^\([^)]*?%s' % re.escape(needle), line):
            return line.rsplit(b' "/" ', 1)[1], folders
    return None, folders


@command(lock=False)
def select_tag(con, tag, readonly=True, exc=True):
    """Select the folder found by ``find_folder`` for ``tag``.

    When no folder matches, raises ``Error`` if ``exc`` is truthy and
    returns ``None`` otherwise.
    """
    folder, folders = find_folder(con, tag)
    if folder is not None:
        return select(con, folder, readonly)
    if exc:
        raise Error('No folder with tag: %s\n%s' % (tag, folders))
    return None


@command()
def status(con, box, fields):
    """Request STATUS ``fields`` for ``box`` (current mailbox if ``None``)."""
    target = box if box is not None else con.current_box
    return check(con.status(target, fields))


@command()
def search(con, *criteria):
    """Run ``UID SEARCH`` and return the uids as a list of strings."""
    res = check(con.uid('SEARCH', None, *criteria))
    found = res[0].decode()
    return found.split()


@command(writable=True)
def append(con, box, flags, date_time, msg):
    """Append ``msg`` to ``box`` and return its uid as a string.

    The Recent flag is removed from ``flags`` before sending.  The uid is
    taken from the ``APPENDUID`` response (the text after the first space).
    ``con.uidnext`` is advanced to the uid following the appended message.
    """
    check(con.append(box, clean_recent(flags), date_time, msg))
    uidlatest = con.untagged_responses.pop('APPENDUID')
    uidlatest = uidlatest[0].decode().split(' ', 1)[-1]
    # update "uidnext" because some stuff is relying on it
    # for example metadata cache
    # stored as int, the same type ``select`` assigns, so the attribute
    # does not flip between int and str depending on the last command
    con.uidnext = int(uidlatest) + 1
    return uidlatest


@command(writable=True)
@cmd_writable
def expunge(con):
    """Expunge the selected mailbox and return the response data."""
    res = con.expunge()
    return check(res)


@command()
def copy(con, uids, box):
    """Run ``UID COPY`` for ``uids`` (an iterable of strings) into ``box``."""
    uid_set = ','.join(uids)
    return check(con.uid('COPY', uid_set, box))


@command(lock=False)
def fetch(con, uids, fields):
    """Run ``UID FETCH`` of ``fields`` for ``uids``.

    Batched ``Uids`` are fetched through ``Uids.call_async`` and the partial
    results concatenated.  A response of ``[None]`` yields ``[]``.
    """
    uids = Uids(uids)
    if uids.batches:
        chunks = uids.call_async(fetch, con, uids, fields)
        return sum(chunks, [])

    desc = fn_desc(fetch, con, uids, fields)
    with con.lock:
        timed_uid = fn_time(con.uid, desc)
        res = check(timed_uid('FETCH', uids.str, fields))
    nothing = len(res) == 1 and res[0] is None
    return [] if nothing else res


@command(lock=False, writable=True)
@cmd_writable
def store(con, uids, cmd, flags):
    """Run ``UID STORE`` applying ``cmd`` with ``flags`` to ``uids``.

    Returns ``[]`` without contacting the server when ``uids`` is empty or
    no flags remain after ``clean_recent``.  Batched ``Uids`` go through
    ``Uids.call_async``; a response of ``[None]`` yields ``[]``.
    """
    if not uids:
        return []

    flags = clean_recent(flags)
    if not flags:
        return []

    uids = Uids(uids)
    if uids.batches:
        chunks = uids.call_async(store, con, uids, cmd, flags)
        return sum(chunks, [])

    desc = fn_desc(store, con, uids, cmd, flags)
    with con.lock:
        timed_uid = fn_time(con.uid, desc)
        res = check(timed_uid('STORE', uids.str, cmd, flags))
    nothing = len(res) == 1 and res[0] is None
    return [] if nothing else res


class Threads(tuple):
    """Tuple of threads (lists of uids) carrying a flat ``all_uids`` list."""

    def __new__(cls, thrs, uids):
        obj = tuple.__new__(cls, thrs)
        obj.all_uids = uids
        return obj


def parse_thread(line):
    """Parse a THREAD response line into ``Threads``.

    line: ``str`` or ``bytes`` such as ``(1)(2 3)(4 (5)(6))``.

    Every top-level parenthesised group becomes one thread: a flat list of
    the uids (as strings) found inside it, nesting ignored.  ``all_uids`` on
    the result lists the uids of all threads in order of appearance.
    """
    if isinstance(line, bytes):
        line = line.decode()

    threads = []
    all_uids = []
    uids = []
    uid = ''
    depth = 0
    for char in line:
        if char not in '() ':
            uid += char
            continue

        # any delimiter ends the uid being collected; nothing is added when
        # no uid is pending, so stray spaces never produce empty uids
        if uid:
            uids.append(uid)
            uid = ''

        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                # a top-level group has been closed: one complete thread
                threads.append(uids)
                all_uids.extend(uids)
                uids = []
    return Threads(threads, all_uids)


def pack_uids(uids):
    """Pack uids into a compact sequence set such as ``1:3,5``.

    uids: iterable of values convertible with ``int``; order does not
        matter and duplicates are ignored.

    Runs of consecutive uids are written as ``first:last``, single uids as
    is, all joined with ``,``.  Empty input gives ``''``.
    """
    # de-duplicate so repeated uids cannot yield overlapping ranges
    ordered = sorted({int(i) for i in uids})
    if not ordered:
        return ''

    runs = []
    first = last = ordered[0]
    for uid in ordered[1:]:
        if uid - last != 1:
            runs.append((first, last))
            first = uid
        last = uid
    runs.append((first, last))

    return ','.join(
        str(lo) if lo == hi else '%d:%d' % (lo, hi)
        for lo, hi in runs
    )


class Uids:
    """Holder for uids that can split itself into batches.

    ``val`` keeps the uids as given: either a ready ``str``/``bytes`` value
    or a sized, sliceable collection.  A collection longer than ``batch`` is
    additionally cut into ``batches`` (a tuple of ``Uids``); ``call_async``
    runs a function over them on a gevent pool of ``threads`` greenlets.
    """
    __slots__ = ['val', 'batches', 'threads']

    def __init__(self, uids, *, batch=10000, threads=10):
        # re-wrapping an existing instance starts from its raw value
        if isinstance(uids, Uids):
            uids = uids.val
        self.threads = threads
        self.val = uids
        self.batches = None
        # str/bytes values are never split
        if not self.is_str and len(uids) > batch:
            self.batches = tuple(
                Uids(uids[i:i+batch], batch=batch)
                for i in range(0, len(uids), batch)
            )

    @property
    def str(self):
        """The value as given when it is str/bytes, else uids joined by ','."""
        if self.is_str:
            return self.val
        return ','.join(str(i) for i in self.val)

    @property
    def is_str(self):
        """True when the value is ``str`` or ``bytes``."""
        return isinstance(self.val, (str, bytes))

    def _call(self, fn, *args):
        """Yield one timed callable per batch.

        ``args`` has to contain this instance (looked up with ``==``); that
        position is replaced by each batch in turn.  Without batches a single
        callable for this instance is yielded, or nothing when ``val`` is
        empty.
        """
        num, uids = [i for i in enumerate(args) if self == i[1]][0]
        args = list(args)
        for i, few in enumerate(uids.batches or ([self] if self.val else [])):
            args[num] = few
            f = ft.partial(fn, *args)
            desc = fn_desc(fn, *args)
            yield fn_time(f, '#%s %s' % (i, desc))

    def call(self, fn, *args):
        """Run the callables from ``_call`` one by one; return their results."""
        return [f() for f in self._call(fn, *args)]

    @fn_time
    def call_async(self, fn, *args):
        """Run ``fn`` over the batches on a gevent pool.

        Falls back to ``call`` when there are no batches.  No further batch
        is spawned once a job has failed.  Raises ``ValueError`` listing the
        exceptions when any job failed, otherwise returns a generator over
        the job values in spawn order.
        """
        if not self.batches:
            return self.call(fn, *args)

        def get_exceptions():
            return [j.exception for j in jobs if j.exception]

        jobs = []
        pool = Pool(self.threads)
        for f in self._call(fn, *args):
            if pool.wait_available():
                # stop scheduling as soon as an earlier job has failed
                if get_exceptions():
                    break
                jobs.append(pool.spawn(f))
        pool.join()

        exceptions = get_exceptions()
        if exceptions:
            raise ValueError('Exception in the pool: %s' % exceptions)
        return (f.value for f in jobs)

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Short form for logs: the uids themselves or just their count."""
        fmt = '"%s uids"'
        if self.is_str:
            uids = self.val
            uids = uids if isinstance(uids, str) else uids.decode()
            # ranges are shown as is, plain lists as a count
            return uids if ':' in uids else fmt % (uids.count(',') + 1)
        if len(self.val) < 5:
            # show few uids as is
            return str(self.val)
        return fmt % len(self.val)