import functools
from collections import defaultdict, namedtuple
from inspect import getframeinfo, stack
from typing import Any, Callable, Dict, List, Optional, cast

from click import echo
from eth_typing.evm import BlockNumber, ChecksumAddress, HexAddress
from eth_utils import to_checksum_address
from web3 import Web3
from web3._utils.events import get_event_data
from web3._utils.filters import LogFilter as Web3LogFilter, construct_event_filter_params
from web3._utils.threads import Timeout

# A concrete event added in a transaction.
from web3.types import ABI, ABIEvent, BlockIdentifier

LogRecorded = namedtuple("LogRecorded", "message callback count")
GenesisBlock = BlockNumber(0)


class LogHandler:
    def __init__(self, web3: Web3, address: HexAddress, abi: ABI):
        self.web3 = web3
        self.address = address
        self.abi = abi
        self.event_waiting: Dict[str, Dict[str, LogRecorded]] = {}
        self.event_filters: Dict[str, LogFilter] = {}
        self.event_count: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(lambda: 0))
        self.event_unknown: List[Dict[str, Any]] = []

    def add(
        self,
        txn_hash: str,
        event_name: str,
        callback: Optional[Callable[..., Any]] = None,
        count: int = 1,
    ) -> None:
        caller = getframeinfo(stack()[1][0])
        message = "%s:%d" % (caller.filename, caller.lineno)

        if event_name not in self.event_waiting:
            self.event_waiting[event_name] = {}
            self.event_filters[event_name] = LogFilter(
                web3=self.web3,
                abi=self.abi,
                address=to_checksum_address(self.address),
                event_name=event_name,
                callback=self.handle_log,
            )

        self.event_waiting[event_name][txn_hash] = LogRecorded(
            message=message, callback=callback, count=count
        )

    def check(self, timeout: int = 5) -> None:
        for event in list(self.event_filters.keys()):
            self.event_filters[event].init()

        self.wait(timeout)

    def _handle_waited_log(self, event: Dict[str, Any]) -> None:
        """ A subroutine of handle_log
        Increment self.event_count, forget about waiting, and call the callback if any.
        """
        txn_hash = event["transactionHash"]
        event_name = event["event"]
        assert event_name in self.event_waiting
        assert txn_hash in self.event_waiting[event_name]

        self.event_count[event_name][txn_hash] += 1
        event_entry = self.event_waiting[event_name][txn_hash]

        if event_entry.count == self.event_count[event_name][txn_hash]:
            self.event_waiting[event_name].pop(txn_hash)

        # Call callback function with event
        if event_entry.callback:
            event_entry.callback(event)

    def handle_log(self, event: Dict[str, Any]) -> None:
        txn_hash = event["transactionHash"]
        event_name = event["event"]

        if event_name in self.event_waiting:
            if txn_hash in self.event_waiting[event_name]:
                self._handle_waited_log(event)
            else:
                self.event_unknown.append(event)
            if not len(list(self.event_waiting[event_name].keys())):
                self.event_waiting.pop(event_name, None)
                self.event_filters.pop(event_name, None)

    def wait(self, seconds: int) -> None:
        try:
            with Timeout(seconds) as timeout:
                while len(list(self.event_waiting.keys())):
                    timeout.sleep(2)
        except Exception as e:
            echo(e, err=True)
            message = "NO EVENTS WERE TRIGGERED FOR: " + str(self.event_waiting)
            if len(self.event_unknown) > 0:
                message += "\n UNKOWN EVENTS: " + str(self.event_unknown)

            # FIXME Events triggered in an internal transaction
            # don't have the transactionHash we are looking for here
            # so we just check if the number of unknown events we find
            # is the same as the found events
            waiting_events = sum([len(lst) for lst in self.event_waiting.values()])

            if waiting_events == len(self.event_unknown):
                sandwitch_echo(message)
            else:
                raise Exception(
                    message + " waiting_events " + str(waiting_events),
                    " len(self.event_unknown) " + str(len(self.event_unknown)),
                )

    def assert_event(
        self, txn_hash: str, event_name: str, args: List[Any], timeout: int = 5
    ) -> None:
        """ Assert that `event_name` is emitted with the `args`

        For use in tests only.
        """

        def assert_args(event: Dict[str, Any]) -> None:
            assert event["args"] == args, f'{event["args"]} == {args}'

        self.add(txn_hash=txn_hash, event_name=event_name, callback=assert_args)
        self.check(timeout=timeout)


def sandwitch_echo(msg: str) -> None:
    echo("----------------------------------", err=True)
    echo(msg, err=True)
    echo("----------------------------------", err=True)


class LogFilter:
    def __init__(
        self,
        web3: Web3,
        abi: ABI,
        address: ChecksumAddress,
        event_name: str,
        from_block: BlockNumber = GenesisBlock,
        to_block: BlockIdentifier = "latest",
        filters: Any = None,
        callback: Optional[Callable[..., Any]] = None,
    ):
        self.web3 = web3
        self.event_name = event_name

        # Callback for every registered log
        self.callback = callback

        event_abi = [i for i in abi if i["type"] == "event" and i["name"] == event_name]
        if len(event_abi) == 0:
            raise ValueError(f"Event of name {event_name} not found")

        self.event_abi = cast(ABIEvent, event_abi[0])
        assert self.event_abi

        filters = filters if filters else {}

        data_filter_set, filter_params = construct_event_filter_params(
            event_abi=self.event_abi,
            abi_codec=web3.codec,
            contract_address=address,
            argument_filters=filters,
            fromBlock=from_block,
            toBlock=to_block,
        )
        log_data_extract_fn = functools.partial(get_event_data, web3.codec, event_abi)

        self.filter: Web3LogFilter = web3.eth.filter(filter_params)  # type: ignore
        self.filter.set_data_filters(data_filter_set)  # type: ignore
        self.filter.log_entry_formatter = log_data_extract_fn
        self.filter.filter_params = filter_params

    def init(self, post_callback: Optional[Callable[[], None]] = None) -> None:
        for log in self.get_logs():
            log["event"] = self.event_name
            if self.callback:
                self.callback(log)
        if post_callback:
            post_callback()

    def get_logs(self) -> List[Any]:
        assert self.filter.filter_id is not None
        logs = self.web3.eth.getFilterLogs(self.filter.filter_id)
        formatted_logs = []
        for log in [dict(log) for log in logs]:
            formatted_logs.append(self.set_log_data(log))
        return formatted_logs

    def set_log_data(self, log: Dict[str, Any]) -> Dict[str, Any]:
        log["args"] = get_event_data(
            abi_codec=self.web3.codec, event_abi=self.event_abi, log_entry=log
        )["args"]
        log["event"] = self.event_name
        return log

    def uninstall(self) -> None:
        assert self.web3 is not None
        assert self.filter is not None
        assert self.filter.filter_id is not None
        self.web3.eth.uninstallFilter(self.filter.filter_id)
        del self.filter