# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from os import environ
from time import time_ns
from io import BytesIO
from threading import Condition
from typing import BinaryIO, Optional, Union

import attr
from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.trial import unittest
from twisted.web.http_headers import Headers
from twisted.web.server import Request

import psycopg2

from sygnal.http import PushGatewayApiServer
from sygnal.sygnal import CONFIG_DEFAULTS, Sygnal, merge_left_with_defaults

# Request path that the test helpers POST notifications to.
REQ_PATH = b"/_matrix/push/v1/notify"

# Switches the test database backend from in-memory SQLite to Postgres.
# NOTE: environ.get returns the raw string when the variable is set, so ANY
# non-empty value (including "0" or "false") is truthy and enables Postgres.
USE_POSTGRES = environ.get("TEST_USE_POSTGRES", False)
# the dbname we will connect to in order to create the base database.
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
# Connection settings for the test Postgres server; each is None when the
# corresponding environment variable is unset.
POSTGRES_USER = environ.get("TEST_POSTGRES_USER", None)
POSTGRES_PASSWORD = environ.get("TEST_POSTGRES_PASSWORD", None)
POSTGRES_HOST = environ.get("TEST_POSTGRES_HOST", None)


class TestCase(unittest.TestCase):
    """
    Base test case that builds a ``Sygnal`` instance on a fake reactor.

    ``setUp`` creates the configuration (optionally backed by a throwaway
    Postgres database when ``USE_POSTGRES`` is truthy), starts the database
    and the pushkins, and exposes the push gateway API as ``self.v1api``.
    ``_request`` then drives a fake HTTP request through that API.
    """

    def config_setup(self, config):
        """
        Fill in the ``database`` section of the test configuration in place.

        Called by ``setUp`` before the defaults are merged in.  A fresh,
        timestamp-based database name is always stored in ``self.dbname``,
        although it is only used when running against Postgres.

        Args:
            config (dict): configuration dict to mutate.
        """
        self.dbname = "_sygnal_%s" % (time_ns())
        if USE_POSTGRES:
            config["database"] = {
                "name": "psycopg2",
                "args": {
                    "user": POSTGRES_USER,
                    "password": POSTGRES_PASSWORD,
                    "database": self.dbname,
                    "host": POSTGRES_HOST,
                },
            }
        else:
            config["database"] = {"name": "sqlite3", "args": {"dbfile": ":memory:"}}

    def _run_admin_statements(self, *statements):
        """
        Run SQL statements, in order, against the maintenance database.

        The connection and cursor are closed even if a statement fails, so a
        failing CREATE/DROP does not leave a connection open.

        Args:
            *statements (str): complete SQL statements to execute.
        """
        conn = psycopg2.connect(
            database=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
            user=POSTGRES_USER,
            password=POSTGRES_PASSWORD,
            host=POSTGRES_HOST,
        )
        try:
            # CREATE/DROP DATABASE cannot run inside a transaction block, so
            # every statement has to be committed implicitly.
            conn.autocommit = True
            cur = conn.cursor()
            try:
                for statement in statements:
                    cur.execute(statement)
            finally:
                cur.close()
        finally:
            conn.close()

    def _set_up_database(self, dbname):
        """
        (Re)create an empty Postgres database called ``dbname``.

        Args:
            dbname (str): database name; interpolated directly into the SQL,
                so it must be a trusted identifier (``setUp`` passes the
                generated ``self.dbname``).
        """
        self._run_admin_statements(
            "DROP DATABASE IF EXISTS %s;" % (dbname,),
            "CREATE DATABASE %s;" % (dbname,),
        )

    def _tear_down_database(self, dbname):
        """
        Drop the Postgres database called ``dbname``.

        Args:
            dbname (str): database name; interpolated directly into the SQL,
                so it must be a trusted identifier.
        """
        self._run_admin_statements("DROP DATABASE %s;" % (dbname,))

    def setUp(self):
        """
        Build and start the Sygnal instance under test.

        Sets ``self.sygnal`` and ``self.v1api`` and does not return until the
        pushkin start-up Deferred has fired.
        """
        reactor = ExtendedMemoryReactorClock()

        config = {"apps": {}, "log": {"setup": {"version": 1}}}

        self.config_setup(config)

        config = merge_left_with_defaults(CONFIG_DEFAULTS, config)
        if USE_POSTGRES:
            self._set_up_database(self.dbname)

        self.sygnal = Sygnal(config, reactor)
        self.sygnal.database.start()
        self.v1api = PushGatewayApiServer(self.sygnal)

        start_deferred = ensureDeferred(
            self.sygnal._make_pushkins_then_start(0, [], None)
        )

        while not start_deferred.called:
            # we need to advance until the pushkins have started up
            self.sygnal.reactor.advance(1)
            self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)

    def tearDown(self):
        """Close the Sygnal database and drop the Postgres database, if any."""
        super().tearDown()
        self.sygnal.database.close()
        if USE_POSTGRES:
            self._tear_down_database(self.dbname)

    def _make_dummy_notification(self, devices):
        """
        Build a fully populated ``m.room.message`` notification payload.

        Args:
            devices: value stored under the ``devices`` key.

        Returns (dict):
            ``{"notification": {...}}`` payload.
        """
        return {
            "notification": {
                "id": "$3957tyerfgewrf384",
                "room_id": "!slw48wfj34rtnrf:example.com",
                "event_id": "$qTOWWTEL48yPm3uT-gdNhFcoHxfKbZuqRVnnWWSkGBs",
                "type": "m.room.message",
                "sender": "@exampleuser:matrix.org",
                "sender_display_name": "Major Tom",
                "room_name": "Mission Control",
                "room_alias": "#exampleroom:matrix.org",
                "prio": "high",
                "content": {
                    "msgtype": "m.text",
                    "body": "I'm floating in a most peculiar way.",
                },
                "counts": {"unread": 2, "missed_calls": 1},
                "devices": devices,
            }
        }

    def _make_dummy_notification_event_id_only(self, devices):
        """
        Build a notification payload carrying only room/event IDs and counts.

        Args:
            devices: value stored under the ``devices`` key.

        Returns (dict):
            ``{"notification": {...}}`` payload.
        """
        return {
            "notification": {
                "room_id": "!slw48wfj34rtnrf:example.com",
                "event_id": "$qTOWWTEL48yPm3uT-gdNhFcoHxfKbZuqRVnnWWSkGBs",
                "counts": {"unread": 2},
                "devices": devices,
            }
        }

    def _make_dummy_notification_badge_only(self, devices):
        """
        Build a notification payload with empty event fields and only counts.

        Args:
            devices: value stored under the ``devices`` key.

        Returns (dict):
            ``{"notification": {...}}`` payload.
        """
        return {
            "notification": {
                "id": "",
                "type": None,
                "sender": "",
                "counts": {"unread": 2},
                "devices": devices,
            }
        }

    def _request(self, payload) -> Union[dict, int]:
        """
        Make a dummy request to the notify endpoint with the specified payload

        Args:
            payload: payload to be JSON encoded.  A dict is serialised with
                ``json.dumps``; anything else must provide ``.encode()``
                (i.e. an already-serialised str) and is sent as-is.

        Returns (dict or int):
            If successful (200 response received), the response is JSON decoded
            and the resultant dict is returned.
            If the response code is not 200, returns the response code.
        """
        if isinstance(payload, dict):
            payload = json.dumps(payload)
        content = BytesIO(payload.encode())

        channel = FakeChannel(self.v1api.site, self.sygnal.reactor)
        channel.process_request(b"POST", REQ_PATH, content)

        while not channel.done:
            # we need to advance until the request has been finished
            self.sygnal.reactor.advance(1)
            self.sygnal.reactor.wait_for_work(lambda: channel.done)

        if channel.result.code != 200:
            return channel.result.code

        return json.loads(channel.response_body)


class ExtendedMemoryReactorClock(MemoryReactorClock):
    """
    A ``MemoryReactorClock`` that can be waited on for newly scheduled work.

    Every ``callLater`` notifies ``work_notifier``, so a test can block in
    ``wait_for_work`` until something has been scheduled instead of spinning.
    """

    def __init__(self):
        super().__init__()
        # Guards scheduling and wakes up anyone blocked in wait_for_work.
        self.work_notifier = Condition()

    def callFromThread(self, function, *args, **kwargs):
        """
        Schedule ``function(*args, **kwargs)`` to run on the next advance.

        Implemented as a zero-delay ``callLater`` so waiters are notified.
        """
        self.callLater(0, function, *args, **kwargs)

    def callLater(self, when, what, *a, **kw):
        """
        Schedule a call as the parent class does, then wake up all waiters.

        Returns:
            Whatever the parent ``callLater`` returns for the scheduled call.
        """
        with self.work_notifier:
            delayed_call = super().callLater(when, what, *a, **kw)
            self.work_notifier.notify_all()

        return delayed_call

    def wait_for_work(self, early_stop=lambda: False):
        """
        Blocks until there is work as long as the early stop condition
        is not satisfied.

        Note that the condition is only re-evaluated when a waiter is woken
        up, which happens each time ``callLater`` is invoked.

        Args:
            early_stop: Extra function called that determines whether to stop
                blocking.
                Should return true iff the early stop condition is satisfied,
                in which case no blocking will be done.
                It is intended to be used to detect when the task you are
                waiting for is complete, e.g. a Deferred has fired or a
                Request has been finished.
        """
        with self.work_notifier:
            while not self.getDelayedCalls() and not early_stop():
                self.work_notifier.wait()


class DummyResponse:
    """
    Bare object carrying a single ``code`` attribute.

    Presumably stands in for an HTTP response in tests that only inspect
    the status code.
    """

    def __init__(self, code):
        # Stored as given; no validation or conversion is applied.
        self.code = code


def make_async_magic_mock(ret_val):
    """
    Create a coroutine function that ignores its arguments.

    Args:
        ret_val: value every awaited call resolves to.

    Returns:
        An ``async`` function accepting any positional and keyword
        arguments and returning ``ret_val``.
    """

    # Arguments are accepted only so the stub fits any call signature.
    async def dummy(*_ignored_args, **_ignored_kwargs):
        return ret_val

    return dummy


@attr.s
class HTTPResult:
    """
    Response head captured by FakeChannel.

    Field order defines the positional constructor signature:
    ``HTTPResult(version, code, reason, headers)``.
    """

    # Types are declared via annotations; attrs records them on each field.
    version: str = attr.ib()
    code: int = attr.ib()
    reason: str = attr.ib()
    headers: Headers = attr.ib()


@attr.s
class FakeChannel:
    """
    Stand-in for a Twisted Web channel (the piece that normally talks to the
    wire).

    It records the response head in ``result``, accumulates the response
    body in ``response_body`` and flips ``done`` once the request reports
    completion.
    """

    # attrs fields, in constructor order.
    site = attr.ib()
    _reactor = attr.ib()
    result: Optional[HTTPResult] = attr.ib(default=None)
    done: bool = attr.ib(default=False)

    # Plain class-level defaults; these are not attrs fields.
    _producer = None
    response_body = b""

    @property
    def transport(self):
        """Always None."""
        return None

    @property
    def code(self):
        """
        Integer status code of the recorded response.

        Raises:
            Exception: if no response head has been written yet.
        """
        if self.result:
            return int(self.result.code)
        raise Exception("No result yet.")

    def getPeer(self):
        """Always None."""
        return None

    def getHost(self):
        """Always None."""
        return None

    def writeHeaders(self, version, code, reason, headers):
        """Record the response head as an HTTPResult (code coerced to int)."""
        status = int(code)
        self.result = HTTPResult(version, status, reason, headers)

    def write(self, content):
        """Append a chunk of bytes to the captured response body."""
        assert isinstance(content, bytes), "Should be bytes! " + repr(content)
        self.response_body = self.response_body + content

    def requestDone(self, _self):
        """Mark the request as finished; the argument is ignored."""
        self.done = True

    def process_request(self, method: bytes, request_path: bytes, content: BinaryIO):
        """
        Simulate the arrival of a request and hand it to the site.

        Args:
            method: HTTP method, e.g. ``b"POST"``.
            request_path: request path.
            content: file-like object holding the request body.
        """
        # A real HTTPChannel would build the request while parsing the
        # incoming lines; here it is created and dispatched directly.
        request = self.site.requestFactory(self)  # type: Request
        request.content = content
        request.requestReceived(method, request_path, b"1.1")