import asyncio
import asyncio.subprocess
import binascii
import copy
import hashlib
import hmac
import os
import socket
import subprocess
import sys
import time
from contextlib import closing

import aiohttp
import pytest
from aiohttp import web
from aiohttp.web_urldispatcher import UrlDispatcher

import threema.gateway
from threema.gateway import e2e
from threema.gateway.key import Key

# Turn off deprecation warnings for now (the environment is also copied
# into subprocesses such as the CLI started by the `cli` fixture)
# TODO: Port code to async/await
os.environ['PYTHONWARNINGS'] = 'ignore'

# Test resource directory ('res') located next to this file
_res_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'res')

class RawMessage(e2e.Message):
    """
    Test message whose nonce and box are supplied verbatim.

    Packing and unpacking are not supported. :meth:`send` either hands the
    stored ``(nonce, message)`` pair back or submits it unchanged through
    the connection's ``send_e2e``.
    """
    def __init__(self, connection, nonce=None, message=None, **kwargs):
        super().__init__(connection, e2e.Message.Type.text_message, **kwargs)
        # Raw nonce and box bytes, passed through untouched by send()
        self.nonce = nonce
        self.message = message

    def pack(self, writer):
        raise NotImplementedError

    @classmethod
    def unpack(cls, connection, parameters, key_pair, reader):
        raise NotImplementedError

    def send(self, get_data_only=False):
        """
        Send the raw message.

        This is a generator-based coroutine (it contains ``yield from``), so
        callers must drive it with ``yield from`` even when *get_data_only*
        is set.

        Arguments:
            - `get_data_only`: If set, nothing is sent and the
              ``(nonce, message)`` tuple is returned instead.

        Return the ID of the message (or the tuple described above).
        """
        if get_data_only:
            return self.nonce, self.message

        # Send message
        return (yield from self._connection.send_e2e(
            to=self.to_id,
            nonce=binascii.hexlify(self.nonce).decode(),
            box=binascii.hexlify(self.message).decode(),
        ))

class Server:
    """
    In-process mock of the Threema Gateway HTTP API used by the fixtures.

    All handlers are registered on :attr:`router`. Each one rejects
    ``from``/``secret`` pairs that are not in
    ``pytest.msgapi.api_identities`` with status 401 and otherwise answers
    with canned data for the ``ECHOECHO`` and ``*MOCKING`` identities.
    Uploaded blobs are kept in memory in :attr:`blobs`, keyed by the MD5
    hex digest of their content.
    """

    def __init__(self):
        # Sample media files from the test resource directory
        self.threema_jpg = os.path.join(_res_path, 'threema.jpg')
        self.threema_mp4 = os.path.join(_res_path, 'threema.mp4')

        # Public key served for ECHOECHO (raw and in 'public:<hex>' form)
        key = b'4a6a1b34dcef15d43cb74de2fd36091be99fbbaf126d099d47d83d919712c72b'
        self.echoecho_key = key
        self.echoecho_encoded_key = 'public:' + key.decode('ascii')

        # Public key served for *MOCKING, derived from the configured private key
        decoded_private_key = Key.decode(pytest.msgapi.private, Key.Type.private)
        self.mocking_key = Key.derive_public(decoded_private_key).hex_pk()

        # In-memory blob storage: blob ID -> blob content
        self.blobs = {}
        self.latest_blob_ids = []

        router = UrlDispatcher()
        router.add_route('GET', '/pubkeys/{key}', self.pubkeys)
        router.add_route('GET', '/lookup/phone/{phone}', self.lookup_phone)
        router.add_route('GET', '/lookup/phone_hash/{phone_hash}', self.lookup_phone_hash)
        router.add_route('GET', '/lookup/email/{email}', self.lookup_email)
        router.add_route('GET', '/lookup/email_hash/{email_hash}', self.lookup_email_hash)
        router.add_route('GET', '/capabilities/{id}', self.capabilities)
        router.add_route('GET', '/credits', self.credits)
        router.add_route('POST', '/send_simple', self.send_simple)
        router.add_route('POST', '/send_e2e', self.send_e2e)
        router.add_route('POST', '/upload_blob', self.upload_blob)
        router.add_route('GET', '/blobs/{blob_id}', self.download_blob)
        self.router = router

    def pubkeys(self, request):
        """Serve the public key of ECHOECHO or *MOCKING (404 otherwise)."""
        key = request.match_info['key']
        from_, secret = request.query['from'], request.query['secret']
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        elif len(key) != 8:
            return web.Response(status=404)
        elif key == 'ECHOECHO':
            return web.Response(body=self.echoecho_key)
        elif key == '*MOCKING':
            return web.Response(body=self.mocking_key)
        return web.Response(status=404)

    def lookup_phone(self, request):
        """Map the phone number 44123456789 to ECHOECHO (404 otherwise)."""
        phone = request.match_info['phone']
        from_, secret = request.query['from'], request.query['secret']
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        elif not phone.isdigit():
            return web.Response(status=404)
        elif phone == '44123456789':
            return web.Response(body=b'ECHOECHO')
        return web.Response(status=404)

    def lookup_phone_hash(self, request):
        """
        Map one known phone hash to ECHOECHO.

        Odd-length hashes yield 500, other hashes that are not 64 characters
        long yield 400 and unknown hashes yield 404.
        """
        phone_hash = request.match_info['phone_hash']
        from_, secret = request.query['from'], request.query['secret']
        hash_ = '98b05f6eda7a878f6f016bdcdc9db6eb61a6b190e814ff787142115af144214c'
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        elif len(phone_hash) % 2 != 0:
            # Note: This status code might not be intended and may change in the future
            return web.Response(status=500)
        elif len(phone_hash) != 64:
            return web.Response(status=400)
        elif phone_hash == hash_:
            return web.Response(body=b'ECHOECHO')
        return web.Response(status=404)

    def lookup_email(self, request):
        """Map the test e-mail address to ECHOECHO (404 otherwise)."""
        email = request.match_info['email']
        from_, secret = request.query['from'], request.query['secret']
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        # NOTE(review): the empty-string literal looks like a redacted test
        # address; an empty `{email}` path segment is unlikely to ever reach
        # this handler — confirm the intended address.
        elif email == '':
            return web.Response(body=b'ECHOECHO')
        return web.Response(status=404)

    def lookup_email_hash(self, request):
        """
        Map one known e-mail hash to ECHOECHO.

        Odd-length hashes yield 500, other hashes that are not 64 characters
        long yield 400 and unknown hashes yield 404.
        """
        email_hash = request.match_info['email_hash']
        from_, secret = request.query['from'], request.query['secret']
        hash_ = '45a13d422b40f81936a9987245d3f6d9064c90607273af4f578246b4484669e2'
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        elif len(email_hash) % 2 != 0:
            # Note: This status code might not be intended and may change in the future
            return web.Response(status=500)
        elif len(email_hash) != 64:
            return web.Response(status=400)
        elif email_hash == hash_:
            return web.Response(body=b'ECHOECHO')
        return web.Response(status=404)

    def capabilities(self, request):
        """Report the capabilities of ECHOECHO and *MOCKING (404 otherwise)."""
        id_ = request.match_info['id']
        from_, secret = request.query['from'], request.query['secret']
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        elif id_ == 'ECHOECHO':
            return web.Response(body=b'text,image,video,file')
        elif id_ == '*MOCKING':
            return web.Response(body=b'text,image,video,file')
        return web.Response(status=404)

    def credits(self, request):
        """Report a fixed credit balance of 100."""
        from_, secret = request.query['from'], request.query['secret']
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)
        return web.Response(body=b'100')

    def send_simple(self, request):
        """
        Accept a simple text message that resolves to ECHOECHO.

        Status codes: 402 for the no-credit identity, 404 if the recipient
        cannot be resolved, 400 for recipients other than ECHOECHO and 413
        for texts longer than 3500 characters.
        """
        post = (yield from request.post())

        # Check API identity
        if (post['from'], post['secret']) not in pytest.msgapi.api_identities:
            return web.Response(status=401)

        # Get ID from to, email or phone
        if 'to' in post:
            id_ = post['to']
        elif post.get('email', None) == '':
            # NOTE(review): the empty-string literal looks like a redacted
            # test address — confirm the intended value.
            id_ = 'ECHOECHO'
        elif post.get('phone', None) == '44123456789':
            id_ = 'ECHOECHO'
        else:
            return web.Response(status=404)

        # Process
        text = post['text']
        if post['from'] == pytest.msgapi.nocredit_id:
            return web.Response(status=402)
        elif id_ != 'ECHOECHO':
            return web.Response(status=400)
        elif len(text) > 3500:
            return web.Response(status=413)
        return web.Response(body=b'0' * 16)

    def send_e2e(self, request):
        """
        Accept an end-to-end encrypted message addressed to ECHOECHO.

        Status codes: 402 for the no-credit identity, 400 for other
        recipients or a nonce that is not 24 bytes long and 413 for boxes
        larger than 4000 bytes.
        """
        post = (yield from request.post())

        # Check API identity
        if (post['from'], post['secret']) not in pytest.msgapi.api_identities:
            return web.Response(status=401)

        # Get ID, nonce and box
        id_ = post['to']
        nonce, box = binascii.unhexlify(post['nonce']), binascii.unhexlify(post['box'])

        # Process
        if post['from'] == pytest.msgapi.nocredit_id:
            return web.Response(status=402)
        elif id_ != 'ECHOECHO':
            return web.Response(status=400)
        elif len(nonce) != 24:
            # Note: This status code might not be intended and may change in the future
            return web.Response(status=400)
        elif len(box) > 4000:
            return web.Response(status=413)
        return web.Response(body=b'1' * 16)

    def upload_blob(self, request):
        """
        Store an uploaded blob and return its ID (MD5 hex digest).

        Status codes: 401 for missing or unknown credentials, 500 if the
        ``blob`` field is missing, 402 for the no-credit identity, 400 for
        an empty blob and 413 for blobs larger than 20 MiB.
        """
        try:
            data = (yield from request.post())

            # Check API identity
            api_identity = (request.query['from'], request.query['secret'])
            if api_identity not in pytest.msgapi.api_identities:
                return web.Response(status=401)
        except KeyError:
            return web.Response(status=401)

        try:
            # Get blob
            blob = data['blob']
        except KeyError:
            # Note: This status code might not be intended and may change in the future
            return web.Response(status=500)

        # Generate ID
        blob_id = hashlib.md5(blob).hexdigest()

        # Process
        if request.query['from'] == pytest.msgapi.nocredit_id:
            return web.Response(status=402)
        elif len(blob) == 0:
            return web.Response(status=400)
        elif len(blob) > 20 * (2**20):
            return web.Response(status=413)

        # Store blob and return
        self.blobs[blob_id] = blob
        return web.Response(body=blob_id.encode())

    def download_blob(self, request):
        """Return a previously uploaded blob (404 for unknown IDs)."""
        blob_id = request.match_info['blob_id']

        # Check API identity
        from_, secret = request.query['from'], request.query['secret']
        if (from_, secret) not in pytest.msgapi.api_identities:
            return web.Response(status=401)

        # Get blob
        try:
            blob = self.blobs[blob_id]
        except KeyError:
            return web.Response(status=404)
        else:
            return web.Response(body=blob)

def pytest_addoption(parser):
    """Register the ``--loop`` command line option (event loop selection)."""
    parser.addoption(
        '--loop',
        action='store',
        help='loop: Use a different event loop, supported: asyncio, uvloop',
    )

def pytest_report_header(config):
    """Add the event loop chosen via ``--loop`` to pytest's report header."""
    loop_name = default_event_loop(config=config)
    return 'Using event loop: ' + format(loop_name)

def pytest_namespace():
    """
    Return the values exposed as ``pytest.msgapi`` to the rest of this file.

    NOTE(review): this block does not parse as shown. The arguments of the
    ``cli_path`` ``os.path.join(`` call and the closing brackets of
    ``values`` and of the ``api_identities`` set are missing. Restore them
    from version control. Also note that pytest removed the
    ``pytest_namespace`` hook in pytest 4.0; confirm the pinned pytest
    version.
    """
    # Private key of the mocked gateway identity; the public key is derived from it
    private = 'private:dd9413d597092b004fedc4895db978425efa328ba1f1ec6729e46e09231b8a7e'
    public = Key.encode(Key.derive_public(Key.decode(private, Key.Type.private)))
    values = {'msgapi': {
        'cli_path': os.path.join(
        'cert_path': os.path.join(_res_path, 'cert.pem'),
        # NOTE(review): the empty base_url and ip look like redacted values — confirm
        'base_url': '',
        'ip': '',
        'id': '*MOCKING',
        'secret': 'mock',
        'private': private,
        'public': public,
        # Identity the mock server answers with 402 (no credits)
        'nocredit_id': 'NOCREDIT',
        # Identity not contained in api_identities, so the mock server rejects it
        'noexist_id': '*NOEXIST',
    # (identity, secret) pairs the mock Server accepts
    values['msgapi']['api_identities'] = {
        (values['msgapi']['id'], values['msgapi']['secret']),
        (values['msgapi']['nocredit_id'], values['msgapi']['secret'])
    return values

def default_event_loop(request=None, config=None):
    """
    Resolve the ``--loop`` option and return the name of the event loop used.

    If ``uvloop`` is requested, its event loop policy is installed, so loops
    created from the current policy afterwards are uvloop loops. Any other
    value, including no option at all, means the standard ``asyncio`` loop.

    Arguments:
        - `request`: A pytest request. Its ``config`` takes precedence over
          *config*.
        - `config`: A pytest config object.

    Return ``'uvloop'`` or ``'asyncio'``.
    """
    if request is not None:
        config = request.config
    loop = config.getoption("--loop")
    if loop == 'uvloop':
        import uvloop
        # Importing uvloop alone has no effect; the policy must be installed
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
    else:
        loop = 'asyncio'
    return loop

def unused_tcp_port(host=None):
    """
    Find an unused TCP port on *host* and return it.

    Binding to port 0 lets the operating system pick a free port. The socket
    is closed again before returning, so another process could in principle
    take the port before the caller binds it.

    Arguments:
        - `host`: Address to bind to, defaults to ``pytest.msgapi.ip``.
    """
    if host is None:
        host = pytest.msgapi.ip
    with closing(socket.socket()) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]

def identity():
    """Return the ``(id, secret)`` API identity of the mocked gateway account."""
    return pytest.msgapi.id, pytest.msgapi.secret

def server():
    """Provide a fresh mock gateway API :class:`Server`."""
    mock_server = Server()
    return mock_server

def raw_message():
    """Provide the :class:`RawMessage` class itself (not an instance)."""
    message_class = RawMessage
    return message_class

def event_loop(request):
    """
    Create a new event loop from the currently installed event loop policy.

    The loop is closed by a finaliser registered on *request*.
    """
    policy = asyncio.get_event_loop_policy()

    # Create new event loop
    _event_loop = policy.new_event_loop()

    def fin():
        _event_loop.close()

    # Add finaliser and return new event loop
    request.addfinalizer(fin)
    return _event_loop

def api_server_port():
    """Provide a free TCP port for the mock API server."""
    port = unused_tcp_port()
    return port

def api_server(request, event_loop, api_server_port, server):
    """
    Serve the routes of the mock *server* on ``pytest.msgapi.ip``.

    The server listens on *api_server_port* and is shut down by a finaliser
    registered on *request*.
    """
    port = api_server_port
    app = web.Application(
        loop=event_loop, router=server.router, client_max_size=100 * (2**20))
    handler = app.make_handler()

    # Set up server
    coroutine = event_loop.create_server(handler, host=pytest.msgapi.ip, port=port)
    server_ = event_loop.run_until_complete(coroutine)

    def fin():
        # Stop listening, then shut the aiohttp application down gracefully
        server_.close()
        event_loop.run_until_complete(server_.wait_closed())
        event_loop.run_until_complete(app.shutdown())
        event_loop.run_until_complete(handler.shutdown(1.0))
        event_loop.run_until_complete(app.cleanup())

    request.addfinalizer(fin)
    return server_

def mock_url(api_server_port):
    """
    Return the URL where the test server can be reached.
    """
    host = pytest.msgapi.ip
    # IPv6 literals must be enclosed in brackets inside a URL (RFC 3986)
    if ':' in host:
        host = '[{}]'.format(host)
    return 'http://{}:{}'.format(host, api_server_port)

def connection(request, api_server, mock_url):
    """
    Return a gateway connection whose API URLs point at the mock server.

    NOTE(review): this block does not parse as shown. The arguments of
    ``threema.gateway.Connection(`` and the body of ``fin`` are missing.
    Presumably ``fin`` should also be registered via
    ``request.addfinalizer``. These parts cannot be inferred from this file;
    restore them from version control.
    """
    # Note: We're not doing anything with the server but obviously the
    # server needs to be started to be able to connect
    connection_ = threema.gateway.Connection(,

    # Patch URLs: replace pytest.msgapi.base_url in every URL with mock_url
    # NOTE(review): with base_url == '' str.replace inserts mock_url between
    # every character — base_url looks redacted; confirm its real value
    connection_.urls = {key: value.replace(pytest.msgapi.base_url, mock_url)
                        for key, value in connection_.urls.items()}

    def fin():

    return connection_

def connection_blocking(request, api_server, mock_url):
    """
    Return a second gateway connection whose API URLs point at the mock server.

    Presumably this is the blocking-mode variant of ``connection``. The
    constructor arguments that would distinguish it are missing, so confirm.

    NOTE(review): this block does not parse as shown. The arguments of
    ``threema.gateway.Connection(`` and the body of ``fin`` are missing and
    cannot be inferred from this file; restore them from version control.
    """
    # Note: We're not doing anything with the server but obviously the
    # server needs to be started to be able to connect
    connection_ = threema.gateway.Connection(,

    # Patch URLs: replace pytest.msgapi.base_url in every URL with mock_url
    # NOTE(review): with base_url == '' str.replace inserts mock_url between
    # every character — base_url looks redacted; confirm its real value
    connection_.urls = {key: value.replace(pytest.msgapi.base_url, mock_url)
                        for key, value in connection_.urls.items()}

    def fin():

    return connection_

def invalid_connection(connection):
    """
    Return a shallow copy of *connection* using an identity the mock server rejects.

    ``pytest.msgapi.noexist_id`` is not part of the accepted API identities.
    """
    invalid_connection_ = copy.copy(connection)
    invalid_connection_.id = pytest.msgapi.noexist_id
    return invalid_connection_

def nocredit_connection(connection):
    """
    Return a shallow copy of *connection* using the no-credit identity.

    The mock server answers sending and uploading requests from
    ``pytest.msgapi.nocredit_id`` with status 402.
    """
    nocredit_connection_ = copy.copy(connection)
    nocredit_connection_.id = pytest.msgapi.nocredit_id
    return nocredit_connection_

def blob():
    """Provide a small, non-empty sample blob."""
    return bytes((1, 2, 3))

def blob_id(event_loop, connection, blob):
    """Upload *blob* through *connection* and return the resulting blob ID."""
    return event_loop.run_until_complete(connection.upload(blob))

def _strip_cli_output(output, rubbish):
    """
    Remove noise from the CLI's combined stdout/stderr *output*.

    Lines starting with any prefix in *rubbish* are dropped. Empty lines are
    dropped at the very beginning, directly after a dropped line, and at the
    end of the output.
    """
    lines = []
    skip_following_empty_lines = True
    for line in output.splitlines(keepends=True):
        if any(line.startswith(prefix) for prefix in rubbish):
            skip_following_empty_lines = True
        elif not skip_following_empty_lines or line.strip():
            lines.append(line)
            skip_following_empty_lines = False

    # Strip trailing empty lines
    while lines and not lines[-1].strip():
        lines.pop()
    return ''.join(lines)

def cli(api_server, api_server_port, event_loop):
    """
    Return a coroutine function that runs the gateway CLI in a subprocess.

    ``call_cli(*args, input=None, timeout=3.0)`` runs
    ``pytest.msgapi.cli_path`` with the current interpreter. It passes the
    mock server's port in the ``THREEMA_TEST_API`` environment variable and
    returns the cleaned-up combined output.

    Raises :class:`ValueError` if the CLI did not report test mode,
    :class:`subprocess.CalledProcessError` on a non-zero exit status and
    :class:`asyncio.TimeoutError` if it does not finish within *timeout*
    seconds. The process is killed in the timeout case.
    """
    def call_cli(*args, input=None, timeout=3.0):
        # Prepare environment
        env = os.environ.copy()
        env['THREEMA_TEST_API'] = str(api_server_port)
        test_api_mode = 'WARNING: Currently running in test mode!'

        # Call CLI in subprocess and get output
        parameters = [sys.executable, pytest.msgapi.cli_path] + list(args)
        if isinstance(input, str):
            input = input.encode('utf-8')

        # Create process
        create = asyncio.create_subprocess_exec(
            *parameters, env=env, stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT)
        process = yield from create

        # Wait for process to terminate
        coroutine = process.communicate(input=input)
        try:
            output, _ = yield from asyncio.wait_for(coroutine, timeout, loop=event_loop)
        except asyncio.TimeoutError:
            # Do not leave the CLI running after the test gave up on it
            try:
                process.kill()
            except ProcessLookupError:
                pass
            yield from process.wait()
            raise

        # Process output
        output = output.decode('utf-8')
        if test_api_mode not in output:
            raise ValueError('Not running in test mode')

        # Strip leading empty lines and pydev debugger output
        # NOTE(review): this list was truncated in the source and may have
        # contained further prefixes — confirm.
        rubbish = [
            'pydev debugger: process',
            'Traceback (most recent call last):',
        ]
        output = _strip_cli_output(output, rubbish)

        # Check return code
        if process.returncode != 0:
            raise subprocess.CalledProcessError(
                process.returncode, parameters, output=output)
        return output
    return call_cli

def private_key_file(tmpdir_factory):
    """
    Return a path for a ``private_key`` file in a fresh temporary directory.

    Only the directory is created here, not the file itself.
    """
    return str(tmpdir_factory.mktemp('keys').join('private_key'))

def public_key_file(tmpdir_factory):
    """
    Return a path for a ``public_key`` file in a fresh temporary directory.

    Only the directory is created here, not the file itself.
    """
    return str(tmpdir_factory.mktemp('keys').join('public_key'))

class Callback(e2e.AbstractCallback):
    """
    Gateway callback that queues every received message for the tests.

    Messages end up in :attr:`queue`, where ``callback_receive`` picks them up.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.loop` is presumably set by the base class from the `loop`
        # keyword argument — TODO confirm
        message_queue = asyncio.Queue(loop=self.loop)
        self.queue = message_queue

    def receive_message(self, message):
        # Generator-based coroutine: suspends until the message is enqueued
        yield from self.queue.put(message)

def callback(event_loop, connection):
    """Provide a :class:`Callback` for *connection* running on *event_loop*."""
    instance = Callback(connection, loop=event_loop)
    return instance

def callback_server_port():
    """Provide a free TCP port for the callback server."""
    port = unused_tcp_port()
    return port

def callback_server(request, event_loop, callback, callback_server_port):
    """
    Start the callback server of *callback* on ``pytest.msgapi.ip``.

    It uses the test certificate at ``pytest.msgapi.cert_path`` and is closed
    by a finaliser registered on *request*.
    """
    cert_path = pytest.msgapi.cert_path
    server_ = event_loop.run_until_complete(callback.create_server(
        certfile=cert_path, host=pytest.msgapi.ip, port=callback_server_port))

    def fin():
        # NOTE(review): assumes create_server() resolves to an asyncio.Server
        # (close()/wait_closed()) — confirm against e2e.AbstractCallback
        server_.close()
        event_loop.run_until_complete(server_.wait_closed())

    request.addfinalizer(fin)
    return server_

def callback_client(request, event_loop, callback_server):
    """
    Return an aiohttp client session for talking to the callback server.

    Certificate verification is disabled. The session is closed by a
    finaliser registered on *request*.
    """
    # Note: This is ONLY required because we are using a self-signed certificate
    #       for test purposes.
    connector = aiohttp.TCPConnector(verify_ssl=False)
    session = aiohttp.ClientSession(connector=connector, loop=event_loop)

    def fin():
        event_loop.run_until_complete(session.close())

    request.addfinalizer(fin)
    return session

def callback_send(callback_client, callback_server_port, connection):
    """
    Return a coroutine function that posts a message to the callback server.

    ``send(message)`` takes the nonce and box from
    ``message.send(get_data_only=True)`` and builds the callback form
    parameters. It signs them with an HMAC-SHA256 keyed with the
    connection's secret, POSTs them to ``/gateway_callback`` over HTTPS and
    returns the client response.
    """
    def send(message):
        # Get data from message
        nonce, data = yield from message.send(get_data_only=True)

        # Create callback parameters
        # NOTE(review): the 'from' entry was missing although the MAC below
        # requires it; the connection's own identity is used — confirm.
        params = {
            'from': connection.id,
            'to': message.to_id,
            'messageId': hashlib.md5(message.to_id.encode('ascii')).hexdigest()[16:],
            'date': str(time.time()),
            'nonce': binascii.hexlify(nonce).decode('ascii'),
            'box': binascii.hexlify(data).decode('ascii'),
        }

        # Calculate MAC over the concatenated parameter values
        mac_input = ''.join((params['from'], params['to'], params['messageId'],
                             params['date'], params['nonce'], params['box']))
        encoded_secret = connection.secret.encode('ascii')
        hmac_ = hmac.new(encoded_secret, msg=mac_input.encode('ascii'),
                         digestmod=hashlib.sha256)
        params['mac'] = hmac_.hexdigest()

        # Send message
        url = 'https://{}:{}/gateway_callback'.format(
            pytest.msgapi.ip, callback_server_port)
        return (yield from callback_client.post(url, data=params))

    return send

def callback_receive(event_loop, callback, callback_server):
    """
    Return a coroutine function that waits for the next received message.

    ``receive(timeout=3.0)`` takes the next message from the callback's queue
    and fails with :class:`asyncio.TimeoutError` if none arrives in time.
    """
    def receive(timeout=3.0):
        message = yield from asyncio.wait_for(
            callback.queue.get(), timeout, loop=event_loop)
        return message

    return receive