import asyncio
import io
import json
import logging
import textwrap
import traceback
from contextlib import redirect_stdout

import aiohttp
import discord
import websockets
from discord.ext import commands
from motor.motor_asyncio import AsyncIOMotorClient

from essentials.messagecache import MessageCache
from essentials.multi_server import get_pre
from essentials.settings import SETTINGS


class ClusterBot(commands.AutoShardedBot):
    """An auto-sharded bot that runs one named cluster of shards.

    The constructor creates a fresh event loop, loads the cogs, schedules the
    IPC websocket connection and then blocks in :meth:`run` until the bot is
    closed.

    Keyword arguments consumed here (everything else, including these except
    ``pipe`` and ``cluster_name``, is forwarded to ``AutoShardedBot``):

    * ``pipe`` -- connection object; ``1`` is sent over it once the bot is
      ready, after which it is closed (presumably a multiprocessing pipe back
      to the launcher -- confirm with the caller).
    * ``cluster_name`` -- used in log messages, the log file name and as this
      cluster's identity on the IPC websocket.
    * ``shard_ids`` / ``shard_count`` -- logged at start-up.
    * ``token`` -- bot token passed to :meth:`run`.
    """

    def __init__(self, **kwargs):
        self.pipe = kwargs.pop('pipe')
        self.cluster_name = kwargs.pop('cluster_name')
        # Each cluster runs on its own, freshly created event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        super().__init__(**kwargs, loop=loop)
        self.websocket = None
        self._last_result = None
        self.ws_task = None
        # Replies collected by websocket_loop while eval_wait is True.
        self.responses = asyncio.Queue()
        self.eval_wait = False
        log = logging.getLogger(f"Cluster#{self.cluster_name}")
        log.setLevel(logging.DEBUG)
        # Replace (rather than append to) the handlers so this cluster logs to
        # exactly one file.
        log.handlers = [logging.FileHandler(f'cluster-{self.cluster_name}.log', encoding='utf-8', mode='a')]

        log.info(f'[Cluster#{self.cluster_name}] {kwargs["shard_ids"]}, {kwargs["shard_count"]}')
        self.log = log

        # Populated in on_ready.
        self.owner = None
        self.db = None
        self.session = None
        self.emoji_dict = None
        self.pre = None
        # True once the ready signal has gone through self.pipe; the pipe is
        # closed afterwards, so the signal must only be sent once.
        self._ready_signalled = False

        self.remove_command('help')
        self.load_extension("cogs.eval")
        extensions = ['cogs.config', 'cogs.poll_controls', 'cogs.help', 'cogs.db_api', 'cogs.admin']
        for ext in extensions:
            self.load_extension(ext)

        self.message_cache = MessageCache(self)
        self.refresh_blocked = {}
        self.refresh_queue = {}

        self.loop.create_task(self.ensure_ipc())
        # Blocks until the bot shuts down.
        self.run(kwargs['token'])

    async def on_message(self, message):
        """Process commands, matching the guild's prefix case-insensitively.

        If the message starts with one of the configured prefixes (ignoring
        case), that leading part of ``message.content`` is rewritten to the
        configured spelling so the regular command processing recognises it.
        ``get_pre`` may return a single prefix or a tuple/list of prefixes.
        """
        prefix = await get_pre(self, message)
        content_lower = message.content.lower()
        if isinstance(prefix, (tuple, list)):
            for pfx in prefix:
                # Skip empty prefixes: every message would "start with" them.
                if pfx and content_lower.startswith(pfx.lower()):
                    # Write back the configured spelling (not a lowercased
                    # copy) so a prefix containing uppercase letters still
                    # matches when the commands are processed.
                    message.content = pfx + message.content[len(pfx):]
                    await self.process_commands(message)
                    break
        elif content_lower.startswith(prefix.lower()):
            message.content = prefix + message.content[len(prefix):]
            await self.process_commands(message)

    async def on_ready(self):
        """Finish start-up once the gateway reports ready.

        discord.py does not guarantee this event fires only once, so the Mongo
        client and HTTP session are created only if missing, and the ready
        signal is sent through ``self.pipe`` only the first time. The owner,
        emoji table, prefix cache and presence are refreshed on every call.
        """
        self.owner = await self.fetch_user(SETTINGS.owner_id)
        if self.db is None:
            mongo = AsyncIOMotorClient(SETTINGS.mongo_db)
            self.db = mongo.pollmaster
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
        with open('utils/emoji-compact.json', encoding='utf-8') as emojson:
            self.emoji_dict = json.load(emojson)
        # Guild id -> command prefix, defaulting to 'pm!' when none is stored.
        self.pre = {entry['_id']: entry.get('prefix', 'pm!') async for entry in
                    self.db.config.find({}, {'_id', 'prefix'})}
        await self.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name="pm!help"))

        self.log.info(f'[Cluster#{self.cluster_name}] Ready called.')
        if not self._ready_signalled:
            self._ready_signalled = True
            self.pipe.send(1)
            self.pipe.close()

    async def on_guild_join(self, server):
        """Create a default config document for a newly joined guild.

        Guilds that already have a config are left untouched. New guilds get
        the ``pm!`` prefix and the default admin/user role names, and the
        in-memory prefix cache is updated to match.
        """
        result = await self.db.config.find_one({'_id': str(server.id)})
        if result is None:
            await self.db.config.update_one(
                {'_id': str(server.id)},
                {'$set': {'prefix': 'pm!', 'admin_role': 'polladmin', 'user_role': 'polluser'}},
                upsert=True
            )
            self.pre[str(server.id)] = 'pm!'

    async def on_shard_ready(self, shard_id):
        """Log that an individual shard of this cluster is ready."""
        self.log.info(f'[Cluster#{self.cluster_name}] Shard {shard_id} ready')

    async def on_command_error(self, ctx, exc):
        """Log the full traceback of unexpected command errors.

        Unknown commands and owner-only denials are ignored. Nothing is sent
        back to the invoking channel.
        """
        if not isinstance(exc, (commands.CommandNotFound, commands.NotOwner)):
            self.log.critical(''.join(traceback.format_exception(type(exc), exc, exc.__traceback__)))

    async def on_error(self, *args, **kwargs):
        """Log the exception currently being handled for a failed event."""
        self.log.critical(traceback.format_exc())

    def cleanup_code(self, content):
        """Strip Markdown code formatting from ``content``.

        A fenced block (text starting and ending with triple backticks) loses
        its first and last lines, i.e. the opening fence with its optional
        language tag and the closing fence. Otherwise, surrounding backticks,
        spaces and newlines are stripped.
        """
        # remove ```py\n```
        if content.startswith('```') and content.endswith('```'):
            return '\n'.join(content.split('\n')[1:-1])

        # remove `foo`
        return content.strip('` \n')

    async def close(self, *args, **kwargs):
        """Close the IPC websocket and HTTP session, then shut the bot down.

        Either resource may never have been created (e.g. the IPC connection
        failed), so each is checked first. The base-class shutdown always runs,
        even if closing the websocket raises.
        """
        self.log.info("shutting down")
        try:
            if self.websocket is not None:
                await self.websocket.close()
        finally:
            await super().close()
            if self.session is not None:
                await self.session.close()

    async def exec(self, code):
        """Run ``code`` as the body of a coroutine and return its output as text.

        The code can use ``bot`` and ``_`` (the last non-``None`` result) plus
        this module's globals.

        Returns:
          * on success: the captured stdout followed by the return value, or
            just the stdout (or ``'None'`` if there was none) when the code
            returns ``None``;
          * on a runtime exception: the captured stdout followed by the
            traceback;
          * on a compile error: ``'<ExceptionName>: <message>'``.

        NOTE(review): this executes arbitrary code. In this class it is called
        from the IPC ``eval`` command, so that channel must only be reachable by
        trusted senders.
        """
        env = {
            'bot': self,
            '_': self._last_result
        }

        env.update(globals())

        body = self.cleanup_code(code)
        stdout = io.StringIO()

        to_compile = f'async def func():\n{textwrap.indent(body, "  ")}'

        try:
            exec(to_compile, env)
        except Exception as e:
            return f'{e.__class__.__name__}: {e}'

        func = env['func']
        try:
            # Note: while func() awaits, output printed by other coroutines is
            # captured here too.
            with redirect_stdout(stdout):
                ret = await func()
        except Exception:
            return f'{stdout.getvalue()}{traceback.format_exc()}'

        value = stdout.getvalue()
        if ret is None:
            return value if value else 'None'
        self._last_result = ret
        return f'{value}{ret}'

    async def websocket_loop(self):
        """Serve commands received over the IPC websocket until it closes.

        Each message is a JSON object. While ``eval_wait`` is set, messages
        carrying a truthy ``response`` are queued on ``self.responses``.
        Messages with a ``command`` (``ping`` or ``eval``; anything else gets
        ``'unknown command'``) are answered with a UTF-8 encoded JSON object
        holding ``response`` and ``author`` (this cluster's name). Malformed
        messages are logged and skipped. A normal closure (code 1000) ends the
        loop; any other closure is re-raised.
        """
        while True:
            try:
                msg = await self.websocket.recv()
            except websockets.ConnectionClosed as exc:
                if exc.code == 1000:
                    return
                raise
            # json.loads accepts both str and UTF-8 bytes; its former
            # ``encoding`` argument no longer exists on Python 3.9+.
            try:
                data = json.loads(msg)
            except ValueError:
                self.log.warning(f"ignoring malformed ws message: {msg!r}")
                continue
            if not isinstance(data, dict):
                self.log.warning(f"ignoring non-object ws message: {msg!r}")
                continue
            if self.eval_wait and data.get('response'):
                await self.responses.put(data)
            cmd = data.get('command')
            if not cmd:
                continue
            if cmd == 'ping':
                ret = {'response': 'pong'}
                self.log.info("received command [ping]")
            elif cmd == 'eval':
                content = data.get('content', '')
                self.log.info(f"received command [eval] ({content})")
                result = await self.exec(content)
                ret = {'response': str(result)}
            else:
                ret = {'response': 'unknown command'}
            ret['author'] = self.cluster_name
            self.log.info(f"responding: {ret}")
            try:
                await self.websocket.send(json.dumps(ret).encode('utf-8'))
            except websockets.ConnectionClosed as exc:
                if exc.code == 1000:
                    return
                raise

    async def ensure_ipc(self):
        """Connect to the local IPC websocket server and start serving it.

        Sends this cluster's name as the first message and waits for one reply
        before starting :meth:`websocket_loop` as a task. If the server closes
        the connection during that handshake, the failure is logged,
        ``self.websocket`` is reset to ``None`` and the exception is re-raised.
        """
        self.websocket = w = await websockets.connect('ws://localhost:42069')
        await w.send(self.cluster_name.encode('utf-8'))
        try:
            await w.recv()
            self.ws_task = self.loop.create_task(self.websocket_loop())
            self.log.info("ws connection succeeded")
        except websockets.ConnectionClosed as exc:
            self.log.warning(f"! couldnt connect to ws: {exc.code} {exc.reason}")
            self.websocket = None
            raise