from collections import deque
from typing import *

import asyncio
import collections
import itertools
import sys
import time
from typing import Any

import hiredis
import uvloop


# Process-wide stores shared by every RedisProtocol instance (one instance is
# created per client connection, so all clients see the same data).
#
# `expiration` maps a key to the time.monotonic() deadline after which the key
# is considered expired.  Being a defaultdict, *indexing* a missing key inserts
# and returns +inf ("never expires"); use `in` / `.get()` to probe without
# inserting.
expiration = collections.defaultdict(lambda: float("inf"))  # type: Dict[bytes, float]
# `dictionary` maps a key to its stored value.  The Python type of the value
# depends on the command that wrote it (hence `Any`).
dictionary = {}  # type: Dict[bytes, Any]


class RedisProtocol(asyncio.Protocol):
    """Per-connection handler for a small subset of the Redis protocol (RESP).

    One instance is created for every client connection; all instances share
    the module-level ``dictionary`` (key -> value) and ``expiration``
    (key -> ``time.monotonic()`` deadline) stores.

    Stored value types: ``bytes`` for strings, ``collections.deque`` for
    lists, ``set`` for sets and ``dict`` for hashes.  Every command handler
    takes the command arguments as ``bytes`` and returns exactly one
    RESP-encoded reply as ``bytes``.  Keys are expired lazily, i.e. an expired
    key is removed the next time a command touches it.
    """

    # Canned RESP replies used by several handlers.
    _OK = b"+OK\r\n"
    _NIL = b"$-1\r\n"
    _EMPTY_ARRAY = b"*0\r\n"
    _SYNTAX_ERROR = b"-ERR syntax error\r\n"
    _NOT_AN_INTEGER = b"-ERR value is not an integer or out of range\r\n"
    _WRONG_TYPE = (
        b"-WRONGTYPE Operation against a key holding the wrong kind of value\r\n"
    )

    def __init__(self):
        self.dictionary = dictionary
        self.expiration = expiration
        self.response = collections.deque()
        self.parser = hiredis.Reader()
        self.transport = None  # type: asyncio.transports.Transport
        self.commands = {
            b"COMMAND": self.command,
            b"SET": self.set,
            b"GET": self.get,
            b"PING": self.ping,
            b"INCR": self.incr,
            b"LPUSH": self.lpush,
            b"RPUSH": self.rpush,
            b"LPOP": self.lpop,
            b"RPOP": self.rpop,
            b"SADD": self.sadd,
            b"HSET": self.hset,
            b"SPOP": self.spop,
            b"LRANGE": self.lrange,
            b"MSET": self.mset,
        }

    def connection_made(self, transport: asyncio.transports.Transport):
        """Remember the transport so replies can be written back later."""
        self.transport = transport

    def data_received(self, data: bytes):
        """Parse every complete request in ``data`` and write the replies.

        Partial requests stay buffered inside the parser until the rest
        arrives.  Replies for all requests completed by this chunk are
        written in one ``writelines`` call (this is what makes pipelining
        cheap).  On a protocol error the pending replies plus an error are
        flushed and the connection is closed, since the byte stream can no
        longer be trusted to be aligned on a request boundary.
        """
        self.parser.feed(data)

        broken = False
        try:
            while True:
                req = self.parser.gets()
                if req is False:  # no complete request buffered
                    break
                self.response.append(self._dispatch(req))
        except hiredis.ProtocolError:
            self.response.append(b"-ERR Protocol error\r\n")
            broken = True

        self.transport.writelines(self.response)
        self.response.clear()
        if broken:
            self.transport.close()

    def _dispatch(self, req) -> bytes:
        """Run one parsed request and return its RESP reply.

        Never raises for a malformed request: anything that is not a
        non-empty array of bulk strings, an unknown command name, or a wrong
        number of arguments each yield an ``-ERR`` reply instead of tearing
        down the connection.
        """
        if (
            not isinstance(req, list)
            or not req
            or not all(isinstance(part, bytes) for part in req)
        ):
            return b"-ERR Protocol error: expected array of bulk strings\r\n"

        name = req[0]
        handler = self.commands.get(name.upper())
        if handler is None:
            # The name is not echoed back: it is arbitrary client bytes and
            # could contain CR/LF, which would corrupt the reply stream.
            return b"-ERR unknown command\r\n"

        try:
            return handler(*req[1:])
        except TypeError:
            # Raised by the call itself when the argument count does not fit
            # the handler's signature.  NOTE: a TypeError raised *inside* a
            # handler is reported the same way.
            return self._arity_error(name.lower())

    # -- reply / storage helpers -------------------------------------------

    @staticmethod
    def _bulk(value: bytes) -> bytes:
        """Encode ``value`` as a RESP bulk string."""
        return b"$%d\r\n%s\r\n" % (len(value), value)

    @staticmethod
    def _integer(number: int) -> bytes:
        """Encode ``number`` as a RESP integer."""
        return b":%d\r\n" % (number,)

    @staticmethod
    def _arity_error(name: bytes) -> bytes:
        """Build the 'wrong number of arguments' error for command ``name``."""
        return b"-ERR wrong number of arguments for '%s' command\r\n" % (name,)

    def _delete(self, key):
        """Remove ``key`` and its expiry deadline, if present."""
        self.dictionary.pop(key, None)
        self.expiration.pop(key, None)

    def _live_value(self, key):
        """Return the value stored at ``key``, or ``None`` if absent/expired.

        An expired key is deleted as a side effect.  ``.get`` is used on the
        expiry table so that probing never inserts a default entry.
        """
        deadline = self.expiration.get(key)
        if deadline is not None and deadline < time.monotonic():
            self._delete(key)
            return None
        return self.dictionary.get(key)

    # -- commands ------------------------------------------------------------

    def command(self):
        # Far from being a complete implementation of the `COMMAND` command of
        # Redis, yet sufficient for us to start using redis-cli.
        return self._OK

    def set(self, *args) -> bytes:
        """SET key value [EX seconds | PX milliseconds] [NX | XX].

        Option names are matched case-insensitively.  A successful SET
        replaces any previous value and discards any previous expiry unless a
        new one is given.  Returns ``+OK``, nil when the NX/XX condition is
        not met, or an ``-ERR`` reply for malformed arguments.
        """
        if len(args) < 2:
            return self._arity_error(b"set")

        key, value = args[0], args[1]
        options = [option.upper() for option in args[2:]]
        expires_at = None
        cond = b""

        if len(options) == 1:
            # SET key value [NX|XX]
            cond = options[0]
        elif len(options) in (2, 3):
            # SET key value [EX seconds | PX milliseconds] [NX|XX]
            unit = options[0]
            if unit not in (b"EX", b"PX"):
                return self._SYNTAX_ERROR
            try:
                duration = int(options[1])
            except ValueError:
                return self._NOT_AN_INTEGER
            if duration <= 0:
                return b"-ERR invalid expire time in set\r\n"
            if unit == b"PX":
                duration = duration / 1000
            expires_at = time.monotonic() + duration
            if len(options) == 3:
                cond = options[2]
        elif options:
            # More trailing arguments than the grammar allows.
            return self._SYNTAX_ERROR

        if cond not in (b"", b"NX", b"XX"):
            return self._SYNTAX_ERROR

        # Checked through _live_value so that an expired key counts as absent.
        exists = self._live_value(key) is not None
        if cond == b"NX" and exists:
            return self._NIL
        if cond == b"XX" and not exists:
            return self._NIL

        self.dictionary[key] = value
        if expires_at is None:
            # A plain SET must not inherit the deadline of the old value.
            self.expiration.pop(key, None)
        else:
            self.expiration[key] = expires_at
        return self._OK

    def get(self, key: bytes) -> bytes:
        """GET key: bulk string, nil if absent/expired, WRONGTYPE otherwise."""
        value = self._live_value(key)
        if value is None:
            return self._NIL
        if not isinstance(value, bytes):
            return self._WRONG_TYPE
        return self._bulk(value)

    def ping(self, message=b"PONG"):
        """PING [message]: echo ``message`` (default ``PONG``) as a bulk string."""
        return self._bulk(message)

    def incr(self, key):
        """INCR key: add one to the integer stored at ``key`` (0 if absent).

        The result is stored back as ``bytes`` so that GET can serve it.
        """
        current = self._live_value(key)
        if current is None:
            number = 0
        elif isinstance(current, bytes):
            try:
                number = int(current)
            except ValueError:
                return self._NOT_AN_INTEGER
        else:
            return self._WRONG_TYPE

        number += 1
        self.dictionary[key] = b"%d" % (number,)
        return self._integer(number)

    def _push(self, name: bytes, key, values, at_left: bool) -> bytes:
        """Shared body of LPUSH/RPUSH; replies with the new list length."""
        if not values:
            return self._arity_error(name)
        items = self._live_value(key)
        if items is None:
            items = collections.deque()
            self.dictionary[key] = items
        elif not isinstance(items, collections.deque):
            return self._WRONG_TYPE

        if at_left:
            # extendleft inserts one by one, so the last value ends up first.
            items.extendleft(values)
        else:
            items.extend(values)
        return self._integer(len(items))

    def lpush(self, key, *values):
        """LPUSH key value [value ...]: prepend values to the list."""
        return self._push(b"lpush", key, values, True)

    def rpush(self, key, *values):
        """RPUSH key value [value ...]: append values to the list."""
        return self._push(b"rpush", key, values, False)

    def _pop(self, key, at_left: bool) -> bytes:
        """Shared body of LPOP/RPOP; nil when the list is absent or empty."""
        items = self._live_value(key)
        if items is None:
            return self._NIL
        if not isinstance(items, collections.deque):
            return self._WRONG_TYPE
        if not items:
            self._delete(key)
            return self._NIL

        value = items.popleft() if at_left else items.pop()
        if not items:
            # Do not keep empty containers around as phantom keys.
            self._delete(key)
        return self._bulk(value)

    def lpop(self, key):
        """LPOP key: remove and return the first element of the list."""
        return self._pop(key, True)

    def rpop(self, key):
        """RPOP key: remove and return the last element of the list."""
        return self._pop(key, False)

    def sadd(self, key, *members):
        """SADD key member [member ...]: reply with how many were new."""
        if not members:
            return self._arity_error(b"sadd")
        members_set = self._live_value(key)
        if members_set is None:
            members_set = set()
            self.dictionary[key] = members_set
        elif not isinstance(members_set, set):
            return self._WRONG_TYPE

        size_before = len(members_set)
        members_set.update(members)
        return self._integer(len(members_set) - size_before)

    def hset(self, key, field, value):
        """HSET key field value: reply 1 if ``field`` is new, 0 if updated."""
        fields = self._live_value(key)
        if fields is None:
            fields = {}
            self.dictionary[key] = fields
        elif not isinstance(fields, dict):
            return self._WRONG_TYPE

        added = int(field not in fields)
        fields[field] = value
        return self._integer(added)

    def spop(self, key):  # TODO add `count`
        """SPOP key: remove and return an arbitrary member, nil if none."""
        members_set = self._live_value(key)
        if members_set is None:
            return self._NIL
        if not isinstance(members_set, set):
            return self._WRONG_TYPE
        if not members_set:
            self._delete(key)
            return self._NIL

        member = members_set.pop()
        if not members_set:
            self._delete(key)
        return self._bulk(member)

    def lrange(self, key, start, stop):
        """LRANGE key start stop: elements from ``start`` to ``stop`` inclusive.

        Negative indices count from the end of the list (-1 is the last
        element) and out-of-range indices are clamped.  A missing key yields
        an empty array.  The array header always carries the number of
        elements actually emitted.
        """
        try:
            start = int(start)
            stop = int(stop)
        except ValueError:
            return self._NOT_AN_INTEGER

        items = self._live_value(key)
        if items is None:
            return self._EMPTY_ARRAY
        if not isinstance(items, collections.deque):
            return self._WRONG_TYPE

        size = len(items)
        if start < 0:
            start = max(size + start, 0)
        if stop < 0:
            stop = size + stop
        stop = min(stop, size - 1)
        if start > stop:
            return self._EMPTY_ARRAY

        selected = list(itertools.islice(items, start, stop + 1))
        return b"*%d\r\n%s" % (
            len(selected),
            b"".join(self._bulk(element) for element in selected),
        )

    def mset(self, *args):
        """MSET key value [key value ...]: set all pairs, or none on error.

        The argument count is validated up front so that a dangling key
        cannot leave the store half-updated.
        """
        if not args or len(args) % 2:
            return self._arity_error(b"mset")
        for key, value in zip(args[::2], args[1::2]):
            self.dictionary[key] = value
            self.expiration.pop(key, None)
        return self._OK


def main(host: str = "127.0.0.1", port: int = 7878) -> int:
    """Run the server on ``host``:``port`` until interrupted with Ctrl+C.

    Installs uvloop's event loop policy, starts a TCP server that creates one
    ``RedisProtocol`` per client connection, and blocks in the event loop.
    The server and the loop are closed on the way out, whether the loop was
    stopped by Ctrl+C or by an unexpected error.

    Returns:
        0, intended to be used as the process exit status.
    """
    print("Hello, World!")

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

    # Create the loop explicitly: relying on asyncio.get_event_loop() to
    # create one implicitly is deprecated and fails on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    server = None
    try:
        # Each client connection will create a new protocol instance
        coro = loop.create_server(RedisProtocol, host, port)
        server = loop.run_until_complete(coro)

        # Serve requests until Ctrl+C is pressed
        print('Serving on {}'.format(server.sockets[0].getsockname()))
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
    finally:
        # Close the server (if it ever started) and always release the loop.
        if server is not None:
            server.close()
            loop.run_until_complete(server.wait_closed())
        loop.close()

    return 0


# Script entry point: run the server and use main()'s return value as the
# process exit status.  Nothing runs when the module is merely imported.
if __name__ == "__main__":
    sys.exit(main())