import asyncio
import base64
import random
from time import time
from unittest.mock import patch

import pytest

from peony import exceptions, oauth

from . import dummy


@pytest.fixture
def oauth1_headers():
    """Provide an ``OAuth1Headers`` instance built from fixed dummy values.

    The first value is the consumer key and the third is the OAuth token
    (both appear in the Authorization headers expected by the tests of
    this module); the second and fourth are presumably the matching
    secrets -- confirm against ``peony.oauth.OAuth1Headers``.
    """
    credentials = ("1234567890", "0987654321", "aaaa", "bbbb")
    return oauth.OAuth1Headers(*credentials)


def dummy_func(arg, **_ignored):
    """Return ``arg`` untouched, discarding any keyword arguments.

    Used in this module as the ``side_effect`` replacing
    ``aiohttp.payload.BytesPayload`` so the raw bytes can be inspected.
    """
    return arg


class MockClient:
    """Minimal stand-in for the client object handed to ``OAuth2Headers``.

    Any item or (otherwise missing) attribute access returns the client
    itself, so arbitrary chains end up on this object; ``post`` counts its
    calls and always answers with a fixed access token.
    """

    def __init__(self):
        # number of times ``post`` has been awaited
        self.count = 0

    def __getitem__(self, key):
        return self

    def __getattr__(self, key):
        return self

    async def post(self, _data=None, _headers=None, **kwargs):
        """Record the call, check the given arguments and return a token.

        When ``_data`` is given, its urlencoded form must be
        ``b"access_token=abc"``; when ``_headers`` is given, its
        ``Authorization`` value must be the Basic auth of
        ``1234567890:0987654321``.
        """
        self.count += 1

        if _data is not None:
            # BytesPayload is replaced by a pass-through so that the raw
            # bytes can be compared directly
            payload_patch = patch.object(oauth.aiohttp.payload,
                                         'BytesPayload',
                                         side_effect=dummy_func)
            with payload_patch:
                encoded = _data._gen_form_urlencoded()
                assert encoded == b"access_token=abc"

        if _headers is not None:
            credentials = "1234567890:0987654321".encode('utf-8')
            token = base64.b64encode(credentials).decode('utf-8')
            assert _headers['Authorization'] == 'Basic ' + token

        # Yield to the event loop to simulate a request being fetched.
        # `test_oauth2_concurrent_refreshes` needs this: without it the
        # refresh tasks would be executed sequentially.
        await asyncio.sleep(0.001)
        return dict(access_token="abc")

    def url(self):
        return ""


@pytest.fixture
def oauth2_headers():
    """Provide an ``OAuth2Headers`` instance backed by a fresh ``MockClient``."""
    mock_client = MockClient()
    return oauth.OAuth2Headers("1234567890", "0987654321",
                               client=mock_client)


def test_oauth1_gen_nonce(oauth1_headers):
    """Two nonces generated one after the other must not be equal."""
    first = oauth1_headers.gen_nonce()
    second = oauth1_headers.gen_nonce()
    assert first != second


def test_oauth1_signature(oauth1_headers):
    """Check ``gen_signature`` against a known value for a simple GET."""
    request = dict(method='GET',
                   url="http://whatever.com",
                   params={'hello': "world"},
                   skip_params=False,
                   oauth={})
    expected = "v3nfF8OWWLfGTuKhi7075R1BBGE="

    signature = oauth1_headers.gen_signature(**request)

    assert expected == signature


def test_oauth1_signature_no_token():
    """Check ``gen_signature`` when only the first two credentials are set."""
    request = dict(method='GET',
                   url="http://whatever.com",
                   params={'hello': "world"},
                   skip_params=False,
                   oauth={})
    expected = "Q9XX4OvdvoOb8ZJyXPrhWiYwOzk="

    # only two positional values are given here, the fixture gives four
    headers = oauth.OAuth1Headers("1234567890", "0987654321")
    signature = headers.gen_signature(**request)

    assert expected == signature


def test_oauth1_signature_queries_safe_chars(oauth1_headers):
    """Check the signature of a query full of punctuation characters."""
    request = dict(method='GET',
                   url="http://whatever.com",
                   params={'q': "@twitter hello :) $:!?/()'*@"},
                   skip_params=False,
                   oauth={'Header': "hello"})
    expected = "ah8dUnveaRVMFisNXKScS6Wy2kU="

    signature = oauth1_headers.gen_signature(**request)

    assert expected == signature


def test_oauth1_sign(oauth1_headers):
    """Check the Authorization header returned by ``sign`` for a POST.

    The clock seen by ``oauth`` is frozen and the ``random`` module is
    seeded before signing; the test then re-seeds, regenerates the nonce
    and signature itself, and compares the rebuilt header with the one
    returned by ``sign``.
    """
    now = time()
    target = {'method': 'POST', 'url': 'http://whatever.com'}

    with patch.object(oauth.time, 'time', return_value=now):
        random.seed(0)
        signed = oauth1_headers.sign(data={'hello': "world"}, **target)

    # same seed as above, so that the same nonce is expected here
    random.seed(0)
    nonce = oauth1_headers.gen_nonce()

    oauth_fields = dict(oauth_consumer_key=oauth1_headers.consumer_key,
                        oauth_nonce=nonce,
                        oauth_signature_method='HMAC-SHA1',
                        oauth_timestamp=str(int(now)),
                        oauth_version='1.0',
                        oauth_token="aaaa")

    signature = oauth1_headers.gen_signature(params={'hello': "world"},
                                             skip_params=False,
                                             oauth=oauth_fields,
                                             **target)

    template = ('OAuth oauth_consumer_key="1234567890", '
                'oauth_nonce="{nonce}", '
                'oauth_signature="{signature}", '
                'oauth_signature_method="HMAC-SHA1", '
                'oauth_timestamp="{time}", '
                'oauth_token="aaaa", '
                'oauth_version="1.0"')
    expected = template.format(nonce=nonce,
                               signature=oauth.quote(signature),
                               time=int(now))

    assert expected == signed['Authorization']


@pytest.mark.parametrize('headers,key', [
    (None, 'data'),
    (None, 'params'),
    ({'Content-Type': "application/x-www-form-urlencoded"}, 'data')
])
def test_oauth1_sign_skip_params(oauth1_headers, headers, key):
    """Check the header returned by ``sign`` when ``skip_params`` is True.

    The request parameters are passed under ``key`` (``data`` or
    ``params``), with or without user headers; the expected signature is
    rebuilt with ``gen_signature(..., skip_params=True)``.
    """
    now = time()

    with patch.object(oauth.time, 'time', return_value=now):
        random.seed(0)
        sign_kwargs = dict(method='POST',
                           url="http://whatever.com",
                           skip_params=True,
                           headers=headers)
        sign_kwargs[key] = {'hello': "world"}
        signed = oauth1_headers.sign(**sign_kwargs)

    # same seed as above, so that the same nonce is expected here
    random.seed(0)
    nonce = oauth1_headers.gen_nonce()

    oauth_fields = dict(oauth_consumer_key=oauth1_headers.consumer_key,
                        oauth_nonce=nonce,
                        oauth_signature_method='HMAC-SHA1',
                        oauth_timestamp=str(int(now)),
                        oauth_version='1.0',
                        oauth_token="aaaa")

    signature = oauth1_headers.gen_signature(method='POST',
                                             url='http://whatever.com',
                                             params={'hello': "world"},
                                             skip_params=True,
                                             oauth=oauth_fields)

    template = ('OAuth oauth_consumer_key="1234567890", '
                'oauth_nonce="{nonce}", '
                'oauth_signature="{signature}", '
                'oauth_signature_method="HMAC-SHA1", '
                'oauth_timestamp="{time}", '
                'oauth_token="aaaa", '
                'oauth_version="1.0"')
    expected = template.format(nonce=nonce,
                               signature=oauth.quote(signature),
                               time=int(now))

    assert expected == signed['Authorization']


def test_headers_options():
    """Check how constructor options show up in the headers mapping.

    ``user_agent`` must set ``User-Agent``, ``compression=False`` must
    leave ``Accept-Encoding`` out and custom ``headers`` must be kept.
    """
    options = {
        'user_agent': "Awesome app",
        'compression': False,
        'headers': {'Custom': "abc"},
    }

    client = oauth.OAuth1Headers("", "", **options)

    assert client['User-Agent'] == "Awesome app"
    assert 'Accept-Encoding' not in client
    assert client['Custom'] == "abc"


@pytest.mark.asyncio
async def test_prepare_request(oauth1_headers):
    """Check the keyword arguments built by ``prepare_request``.

    ``sign`` is patched out and three requests are prepared: a GET with
    ``params``, a POST with ``data`` and a lowercase ``get`` without any
    parameters. The test checks which of ``params``/``data`` ends up in
    each result, and the method, url and headers of the first two.
    """
    async def mock_sign():  # no need to test sign again here
        return oauth1_headers.copy()

    with patch.object(oauth1_headers, 'sign') as sign:
        # a coroutine object can only be awaited once, so a new one is
        # installed as the return value before every call
        sign.return_value = mock_sign()
        kwargs = await oauth1_headers.prepare_request(
            method='GET', url="http://whatever.com", params={'test': 'hello'}
        )

        sign.return_value = mock_sign()
        kwargs_post = await oauth1_headers.prepare_request(
            method='POST', url="http://whatever.com", data={'test': 'hello'}
        )

        sign.return_value = mock_sign()
        kwargs_no_params = await oauth1_headers.prepare_request(
            method='get', url="http://whatever.com"
        )

    assert 'params' in kwargs
    assert 'data' in kwargs_post

    assert kwargs['method'] == 'GET'
    assert kwargs_post['method'] == 'POST'

    # compare the url of BOTH requests (the GET result used to be compared
    # with itself, leaving the POST result unchecked)
    assert kwargs['url'] == kwargs_post['url'] == "http://whatever.com"

    assert kwargs_post['headers'] == kwargs['headers'] == oauth1_headers.copy()
    assert kwargs_post['data'] == kwargs['params'] == {'test': 'hello'}

    assert 'data' not in kwargs
    assert 'params' not in kwargs_post
    assert 'data' not in kwargs_no_params and 'params' not in kwargs_no_params


def test_user_headers(oauth2_headers):
    """Check ``_user_headers`` with and without a token.

    With a token set the ``Authorization`` value must become
    ``Bearer <token>``; once the token is deleted the value given by the
    caller must be kept.
    """
    oauth2_headers.token = "abc"
    with_token = oauth2_headers._user_headers({'Authorization': "cba"})
    assert with_token['Authorization'] == "Bearer abc"

    del oauth2_headers.token
    without_token = oauth2_headers._user_headers({'Authorization': "cba"})
    assert without_token['Authorization'] == "cba"


def test_oauth2_set_token():
    """A ``bearer_token`` given to the constructor must become ``token``."""
    options = dict(client=None, bearer_token="abc")
    oauth2 = oauth.OAuth2Headers("123", "456", **options)
    assert oauth2.token == "abc"


@pytest.mark.asyncio
async def test_oauth2_refresh_token(oauth2_headers):
    """``refresh_token`` must store the token returned by the client."""
    # no token before the first refresh
    assert oauth2_headers.token is None

    await oauth2_headers.refresh_token()

    # "abc" is the access token returned by MockClient.post
    assert oauth2_headers.token == "abc"


@pytest.mark.asyncio
async def test_oauth2_sign(oauth2_headers):
    """Check when ``sign`` calls ``refresh_token``.

    Without a token ``sign`` must call ``refresh_token``; once a real
    call to ``sign`` has set the token, ``sign`` must not call it again.
    """
    target = 'http://whatever.com'

    # no token yet: the (patched) refresh must be triggered
    refresh_patch = patch.object(oauth2_headers, 'refresh_token',
                                 side_effect=dummy)
    with refresh_patch as refresh_token:
        await oauth2_headers.sign(url=target)
        assert refresh_token.called

    # unpatched call: the token gets fetched from the mock client
    await oauth2_headers.sign(url=target)
    assert oauth2_headers.token == "abc"

    # the token is now set: no refresh is expected anymore
    with patch.object(oauth2_headers, 'refresh_token') as refresh_token:
        await oauth2_headers.sign(url=target)
        assert not refresh_token.called


@pytest.mark.asyncio
async def test_oauth2_sign_url_invalidate(oauth2_headers):
    """Signing a request to the ``_invalidate_token`` url resets the token."""
    oauth2_headers.token = "test"

    invalidate_url = oauth2_headers._invalidate_token.url()
    await oauth2_headers.sign(url=invalidate_url)

    assert oauth2_headers.token is None


@pytest.mark.asyncio
async def test_oauth2_concurrent_refreshes(oauth2_headers):
    """Two refreshes run concurrently must result in a single ``post``."""
    client = oauth2_headers.client
    assert client.count == 0

    async def refresh():
        await oauth2_headers.refresh_token()

    pending = [refresh() for _ in range(2)]
    await asyncio.gather(*pending)

    # MockClient.post increments ``count`` on every call
    assert oauth2_headers.client.count == 1


def test_raw_form_data():
    """Check the urlencoded output of ``RawFormData`` without quoting.

    With ``quote_fields=False`` the percent sequences already present in
    the values must appear unchanged in the generated bytes.
    """
    fields = {'access_token': "a%20bc%25",
              'access_token_secret': "cba"}

    # BytesPayload is replaced by a pass-through so that the raw bytes
    # can be compared directly
    payload_patch = patch.object(oauth.aiohttp.payload, 'BytesPayload',
                                 side_effect=dummy_func)
    with payload_patch:
        formdata = oauth.RawFormData(fields, quote_fields=False)
        encoded = formdata._gen_form_urlencoded()
        assert encoded == b"access_token=a%20bc%25&access_token_secret=cba"


@pytest.mark.asyncio
async def test_oauth2_invalidate_token_no_token(oauth2_headers):
    """``invalidate_token`` must raise ``RuntimeError`` on a fresh instance.

    No token has been set on the fixture when the call is made.
    """
    expect_error = pytest.raises(RuntimeError)
    with expect_error:
        await oauth2_headers.invalidate_token()


@pytest.mark.asyncio
async def test_oauth2_invalidate_token_exception(oauth2_headers):
    """A ``PeonyException`` raised by ``client.post`` must propagate.

    The client's ``post`` is patched to raise, and ``invalidate_token``
    is expected to let that exception reach the caller.
    """
    def raise_peony_exception(**kwargs):
        raise exceptions.PeonyException

    failing_post = patch.object(oauth2_headers.client, 'post',
                                side_effect=raise_peony_exception)

    with pytest.raises(exceptions.PeonyException), failing_post:
        oauth2_headers.token = "abc"
        await oauth2_headers.invalidate_token()