#!/usr/bin/env python

import sys
from functools import wraps

try:
    import urllib
    import urlparse
except ImportError:
    import urllib.parse as urllib
    from urllib import parse as urlparse

import struct
import time
import posixpath
import re
import hashlib
import socket
import random
from base64 import b64encode

try:
    from googleapiclient.discovery import build
    from googleapiclient.errors import HttpError
except ImportError:
    from apiclient.discovery import build
    from apiclient.errors import HttpError

import logging
from ._version import get_versions


# Package version string, taken from the 'version' entry returned by the
# package-local ``_version.get_versions`` helper.
__version__ = get_versions()['version']
# The helper is only needed for the lookup above; remove it from the module
# namespace so it is not re-exported.
del get_versions

# Module logger. A NullHandler is attached so that the library itself does
# not emit output unless the application configures logging.
log = logging.getLogger('gglsbl')
log.addHandler(logging.NullHandler())


# Number of consecutive retriable HTTP failures seen by ``autoretry``.
# It is module-level state: shared by every decorated function and reset to
# zero after any successful call.
_fail_count = 0


def autoretry(func):
    """Decorator that retries ``func`` on server-side and socket errors.

    Behaviour of the wrapped call:

    * On success the module-level ``_fail_count`` is reset to 0 and the
      result of ``func`` is returned.
    * On ``HttpError`` whose response status is a number >= 500 the call is
      retried after ``min(2 ** (n - 1) * 15 min * (1 + random), 24 h)``
      seconds, where ``n`` is the number of consecutive such failures.
    * Any other ``HttpError`` (auth errors, bad requests, responses without a
      numeric status) is re-raised unchanged.
    * On ``socket.error`` the call is retried after a fixed 2 second pause.

    There is no upper bound on the number of retries, and the wrapper blocks
    in ``time.sleep`` while waiting.

    :param func: callable to wrap.
    :return: wrapped callable with the same signature as ``func``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        global _fail_count
        while True:
            try:
                r = func(*args, **kwargs)
                _fail_count = 0
                return r
            except HttpError as e:
                # Normalise the status to text first: it is usually a string,
                # but str() keeps the digit check safe for other types too.
                status = ''
                if hasattr(e, 'resp') and 'status' in e.resp:
                    status = str(e.resp['status'])
                # isdigit() must actually be *called*; the bare attribute is
                # always truthy and would let int() fail on odd statuses.
                if not (status.isdigit() and int(status) >= 500):
                    raise  # we do not want to retry auth errors etc.
                _fail_count += 1
                # Exponential back-off starting at 15-30 minutes (random
                # jitter), capped at 24 hours.
                wait_for = min(2 ** (_fail_count - 1) * 15 * 60 * (1 + random.random()), 24 * 60 * 60)
                log.exception('Call Failed for %s time(s). Retrying in %s seconds: %s',
                              _fail_count, wait_for, str(e))
                time.sleep(wait_for)
            except socket.error:
                # NOTE(review): on Python 3 ``socket.error`` is an alias of
                # OSError, so every OSError raised by ``func`` is retried.
                transient_error_wait = 2
                log.exception('Socket error, retrying in {} seconds.'.format(transient_error_wait))
                time.sleep(transient_error_wait)
    return wrapper


class SafeBrowsingApiClient(object):
    """Client for the ``safebrowsing`` v4 service created through ``build()``.

    Each API method records the ``minimumWaitDuration`` returned by the
    server (unless the fair use policy is discarded); callers can honour it
    by invoking :meth:`fair_use_delay` before the next request.
    """

    def __init__(self, developer_key, client_id='python-gglsbl',
                 client_version=__version__, discard_fair_use_policy=True):
        """Constructor.

        :param developer_key: Google API key
        :param client_id: value sent as ``clientId`` in request bodies
        :param client_version: value sent as ``clientVersion`` in request bodies
        :param discard_fair_use_policy: do not wait between individual API calls as requested by the spec
        """
        self.client_id = client_id
        self.client_version = client_version
        self.discard_fair_use_policy = discard_fair_use_policy
        if self.discard_fair_use_policy:
            # Logger.warn() is a deprecated alias; warning() is the supported name.
            log.warning('Circumventing request frequency throttling is against Safe Browsing API policy.')
        self.service = build('safebrowsing', 'v4', developerKey=developer_key, cache_discovery=False)
        # Earliest time.time() value at which the next request may be sent,
        # or None when no delay is pending.
        self.next_request_no_sooner_than = None

    def set_wait_duration(self, minimum_wait_duration):
        """Remember the minimum wait requested by the server.

        Does nothing when the fair use policy is discarded.

        :param minimum_wait_duration: duration text with an optional trailing
            ``s`` (whatever the response carried as ``minimumWaitDuration``),
            or None to clear any pending delay.
        """
        if self.discard_fair_use_policy:
            return
        if minimum_wait_duration is None:
            self.next_request_no_sooner_than = None
            return
        wait_seconds = float(minimum_wait_duration.rstrip('s'))
        self.next_request_no_sooner_than = time.time() + wait_seconds

    def fair_use_delay(self):
        """Sleep until the recorded earliest time for the next request.

        Returns immediately when no wait duration is recorded.
        """
        if self.next_request_no_sooner_than is not None:
            sleep_for = max(0, self.next_request_no_sooner_than - time.time())
            log.info('Sleeping for {} seconds until next request.'.format(sleep_for))
            time.sleep(sleep_for)

    @autoretry
    def get_threats_lists(self):
        """Retrieve all available threat lists.

        :return: the ``threatLists`` entry of the API response.
        """
        response = self.service.threatLists().list().execute()
        self.set_wait_duration(response.get('minimumWaitDuration'))
        return response['threatLists']

    @autoretry
    def get_threats_update(self, client_state):
        """Fetch hash prefixes update for given threat list.

        client_state is a dict which looks like {(threatType, platformType, threatEntryType): clientState}

        :return: the ``listUpdateResponses`` entry of the API response.
        """
        # One update request per known list, asking for uncompressed (RAW) data.
        list_update_requests = [
            {
                "threatType": threat_type,
                "platformType": platform_type,
                "threatEntryType": threat_entry_type,
                "state": current_state,
                "constraints": {
                    "supportedCompressions": ["RAW"]
                }
            }
            for (threat_type, platform_type, threat_entry_type), current_state in client_state.items()
        ]
        request_body = {
            "client": {
                "clientId": self.client_id,
                "clientVersion": self.client_version,
            },
            "listUpdateRequests": list_update_requests,
        }
        response = self.service.threatListUpdates().fetch(body=request_body).execute()
        self.set_wait_duration(response.get('minimumWaitDuration'))
        return response['listUpdateResponses']

    @autoretry
    def get_full_hashes(self, prefixes, client_state):
        """Find full hashes matching hash prefixes.

        client_state is a dict which looks like {(threatType, platformType, threatEntryType): clientState}

        :param prefixes: iterable of hash prefixes as bytes; each one is sent
            base64-encoded.
        :return: the complete API response.
        """
        threat_info = {
            "threatTypes": [],
            "platformTypes": [],
            "threatEntryTypes": [],
            "threatEntries": [{"hash": b64encode(prefix).decode()} for prefix in prefixes],
        }
        request_body = {
            "client": {
                "clientId": self.client_id,
                "clientVersion": self.client_version,
            },
            "clientStates": [],
            "threatInfo": threat_info,
        }
        for (threat_type, platform_type, threat_entry_type), state in client_state.items():
            request_body['clientStates'].append(state)
            # Each type list holds every distinct value once, in first-seen order.
            if threat_type not in threat_info['threatTypes']:
                threat_info['threatTypes'].append(threat_type)
            if platform_type not in threat_info['platformTypes']:
                threat_info['platformTypes'].append(platform_type)
            if threat_entry_type not in threat_info['threatEntryTypes']:
                threat_info['threatEntryTypes'].append(threat_entry_type)
        response = self.service.fullHashes().find(body=request_body).execute()
        self.set_wait_duration(response.get('minimumWaitDuration'))
        return response


class URL(object):
    """URL representation suitable for lookup.

    Wraps a raw URL and exposes its canonical form, the host/path
    permutations of that canonical form and the SHA-256 digests of those
    permutations.
    """

    __py3 = (sys.version_info > (3, 0))

    def __init__(self, url):
        """Constructor.

        :param url: can be either of str or bytes type.
        """
        if self.__py3:
            # Internally the URL is kept as bytes, since it may contain
            # sequences that are not valid unicode.
            if isinstance(url, bytes):
                self.url = bytes(url)
            else:
                self.url = url.encode()
        else:
            self.url = str(url)

    @property
    def hashes(self):
        """Hashes of all possible permutations of the URL in canonical form"""
        for url_variant in self.url_permutations(self.canonical):
            yield self.digest(url_variant)

    @property
    def canonical(self):
        """Convert URL to its canonical form.

        Steps performed: strip surrounding whitespace, drop tab/CR/LF
        characters and the fragment, add a default ``http`` scheme, fully
        percent-decode and re-encode once, normalise the path, collapse and
        trim dots in the host, lower-case it, and turn purely decimal or
        ``0x`` hexadecimal hosts into dotted-quad form.

        :return: canonical URL text of the form
            ``scheme://host[:port]/path[?query]``.
        """
        def full_unescape(u):
            # Percent-decode repeatedly until the value stops changing.
            while True:
                uu = urllib.unquote(u)
                if uu == u:
                    return uu
                u = uu

        def full_unescape_to_bytes(u):
            # Same as full_unescape() but yields bytes (Python 3 code path).
            while True:
                uu = urlparse.unquote_to_bytes(u)
                if uu == u:
                    return uu
                u = uu

        def quote(s):
            safe_chars = '!"$&\'()*+,-./:;<=>?@[\\]^_`{|}~'
            return urllib.quote(s, safe=safe_chars)

        url = self.url.strip()
        url = url.replace(b'\n', b'').replace(b'\r', b'').replace(b'\t', b'')
        url = url.split(b'#', 1)[0]
        if url.startswith(b'//'):
            url = b'http:' + url
        if len(url.split(b'://')) <= 1:
            url = b'http://' + url
        # at python3 work with bytes instead of string
        # as URL may contain invalid unicode characters
        if self.__py3 and isinstance(url, bytes):
            url = quote(full_unescape_to_bytes(url))
        else:
            url = quote(full_unescape(url))
        url_parts = urlparse.urlsplit(url)
        if not url_parts[0]:
            url = 'http://{}'.format(url)
            url_parts = urlparse.urlsplit(url)
        protocol = url_parts.scheme
        if self.__py3:
            host = full_unescape_to_bytes(url_parts.hostname)
            path = full_unescape_to_bytes(url_parts.path)
        else:
            host = full_unescape(url_parts.hostname)
            path = full_unescape(url_parts.path)
        query = url_parts.query
        # Keep an empty query only if the URL really contained a '?'.
        if not query and '?' not in url:
            query = None
        if not path:
            path = b'/'
        # normpath() drops a trailing slash, so remember it and restore it.
        has_trailing_slash = (path[-1:] == b'/')
        path = posixpath.normpath(path).replace(b'//', b'/')
        if has_trailing_slash and path[-1:] != b'/':
            path = path + b'/'
        port = url_parts.port
        host = host.strip(b'.')
        host = re.sub(br'\.+', b'.', host).lower()
        # Best effort: hosts given as a single decimal or hexadecimal number
        # become dotted quads; values that cannot be converted are left
        # untouched on purpose.
        if host.isdigit():
            try:
                host = socket.inet_ntoa(struct.pack("!I", int(host)))
            except Exception:
                pass
        elif host.startswith(b'0x') and b'.' not in host:
            try:
                host = socket.inet_ntoa(struct.pack("!I", int(host, 16)))
            except Exception:
                pass
        quoted_path = quote(path)
        quoted_host = quote(host)
        if port is not None:
            quoted_host = '{}:{}'.format(quoted_host, port)
        canonical_url = '{}://{}{}'.format(protocol, quoted_host, quoted_path)
        if query is not None:
            canonical_url = '{}?{}'.format(canonical_url, query)
        return canonical_url

    @staticmethod
    def url_permutations(url):
        """Yield every distinct host/path combination to look up for ``url``.

        Each yielded value is a host candidate directly followed by a path
        candidate (no scheme, user info or port); duplicates are yielded
        only once.

        :param url: URL text that includes a scheme, e.g. the value of
            ``canonical``.
        """
        def url_host_permutations(host):
            # A dotted-quad host is used as is. The pattern is anchored at
            # the end so that names merely *starting* with four numeric
            # labels (e.g. "1.2.3.4.example.com") still get their suffixes.
            if re.match(r'\d+\.\d+\.\d+\.\d+$', host):
                yield host
                return
            parts = host.split('.')
            label_count = min(len(parts), 5)
            # With five or more labels the exact host comes first, followed
            # by suffixes built from the last five labels.
            if label_count > 4:
                yield host
            for i in range(label_count - 1):
                yield '.'.join(parts[i - label_count:])

        def url_path_permutations(path):
            # Exact path (including query), then the path without query,
            # then up to four leading directory prefixes starting at '/'.
            yield path
            query = None
            if '?' in path:
                path, query = path.split('?', 1)
            if query is not None:
                yield path
            path_parts = path.split('/')[0:-1]
            curr_path = ''
            for i in range(min(4, len(path_parts))):
                curr_path = curr_path + path_parts[i] + '/'
                yield curr_path

        # NOTE(review): the split* helpers are deprecated on Python 3.8+;
        # kept because they behave the same on Python 2 and 3.
        protocol, address_str = urllib.splittype(url)
        host, path = urllib.splithost(address_str)
        user, host = urllib.splituser(host)
        host, port = urllib.splitport(host)
        host = host.strip('/')
        seen_permutations = set()
        for h in url_host_permutations(host):
            for p in url_path_permutations(path):
                u = '{}{}'.format(h, p)
                if u not in seen_permutations:
                    yield u
                    seen_permutations.add(u)

    @staticmethod
    def digest(url):
        """Return the raw SHA-256 digest of ``url`` encoded as UTF-8."""
        return hashlib.sha256(url.encode('utf-8')).digest()


# Manual smoke test: when the module is executed as a script, build a client
# and pretty-print the threat lists returned by the API.
if __name__ == '__main__':
    from pprint import pprint
    # NOTE(review): the API key is hard-coded in source. Anyone with access
    # to this file can use it; consider revoking it and reading the key from
    # the environment or the command line instead.
    c = SafeBrowsingApiClient('AIzaSyATpqLltciaMve61Wywb5yNDA8D8BvXEn4')
    r = c.get_threats_lists()
    pprint(r)