#!/usr/bin/env python

import sys
from functools import wraps

# Python 2 ships `urllib` and `urlparse` as separate modules.  On Python 3 the
# `urlparse` import fails, and both names are then bound to `urllib.parse`,
# which provides the quoting and splitting helpers this module relies on.
try:
    import urllib
    import urlparse
except ImportError:
    import urllib.parse as urllib
    from urllib import parse as urlparse

import struct
import time
import posixpath
import re
import hashlib
import socket
import random
from base64 import b64encode

# Prefer the `googleapiclient` package name and fall back to the `apiclient`
# name when the former cannot be imported.
try:
    from googleapiclient.discovery import build
    from googleapiclient.errors import HttpError
except ImportError:
    from apiclient.discovery import build
    from apiclient.errors import HttpError

import logging
from ._version import get_versions

# Package version string, taken from the 'version' entry returned by
# `get_versions()`.
__version__ = get_versions()['version']
# Remove the helper so it does not linger in the module namespace.
del get_versions

# Module-level logger used by every function and class in this file.
log = logging.getLogger('gglsbl')

# Number of consecutive HTTP failures seen by `autoretry`; it is reset to 0
# after any successful call and drives the exponential back-off delay.
_fail_count = 0

def autoretry(func):
    """Decorator that keeps calling ``func`` until it succeeds.

    Two kinds of failure are retried:

    * ``HttpError`` whose response status is 5xx: retried after an
      exponential back-off (15 minutes doubled on every consecutive failure,
      with random jitter, capped at 24 hours).
    * ``socket.error``: treated as transient and retried after 2 seconds.

    Any other ``HttpError`` (for example authentication or quota errors) is
    re-raised immediately, as are all other exception types.

    :param func: callable to wrap
    :returns: wrapper with the same signature that returns ``func``'s result
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        global _fail_count
        while True:
            try:
                r = func(*args, **kwargs)
                _fail_count = 0
                return r
            except HttpError as e:
                resp = getattr(e, 'resp', None)
                status = ''
                if resp is not None and 'status' in resp:
                    status = str(resp['status'])
                # isdigit() must actually be called; the bare method object
                # is always truthy and would let int() fail on odd statuses.
                if not (status.isdigit() and int(status) >= 500):
                    raise  # we do not want to retry auth errors etc.
                _fail_count += 1
                wait_for = min(2 ** (_fail_count - 1) * 15 * 60 * (1 + random.random()), 24 * 60 * 60)
                log.exception('Call Failed for %s time(s). Retrying in %s seconds: %s',
                              _fail_count, wait_for, str(e))
                # Honour the announced delay instead of retrying in a tight loop.
                time.sleep(wait_for)
            except socket.error:
                transient_error_wait = 2
                log.exception('Socket error, retrying in {} seconds.'.format(transient_error_wait))
                time.sleep(transient_error_wait)
    return wrapper

class SafeBrowsingApiClient(object):
    """Client for the ``safebrowsing`` v4 service built via API discovery.

    Besides issuing the requests, the client remembers the earliest time the
    next request should be sent (``next_request_no_sooner_than``) so callers
    can honour it through ``fair_use_delay()``.
    """

    def __init__(self, developer_key, client_id='python-gglsbl',
                 client_version=__version__, discard_fair_use_policy=True):
        """Constructor.

        :param developer_key: Google API key
        :param client_id: value sent as ``clientId`` in every request body
        :param client_version: value sent as ``clientVersion`` in every request body
        :param discard_fair_use_policy: do not wait between individual API calls as requested by the spec
        """
        self.client_id = client_id
        self.client_version = client_version
        self.discard_fair_use_policy = discard_fair_use_policy
        if self.discard_fair_use_policy:
            # Logger.warn() is a deprecated alias; warning() is the supported name.
            log.warning('Circumventing request frequency throttling is against Safe Browsing API policy.')
        self.service = build('safebrowsing', 'v4', developerKey=developer_key, cache_discovery=False)
        self.next_request_no_sooner_than = None

    def set_wait_duration(self, minimum_wait_duration):
        """Record how long to wait before the next request.

        :param minimum_wait_duration: duration in seconds as a string with an
            optional trailing ``s`` (e.g. ``'593.44s'``), or ``None`` to clear
            any pending wait. Ignored when the fair use policy is discarded.
        """
        if self.discard_fair_use_policy:
            return
        if minimum_wait_duration is None:
            self.next_request_no_sooner_than = None
            return
        self.next_request_no_sooner_than = time.time() + float(minimum_wait_duration.rstrip('s'))

    def fair_use_delay(self):
        """Sleep until the time recorded by ``set_wait_duration()``, if any."""
        if self.next_request_no_sooner_than is not None:
            sleep_for = max(0, self.next_request_no_sooner_than - time.time())
            log.info('Sleeping for {} seconds until next request.'.format(sleep_for))
            time.sleep(sleep_for)

    def get_threats_lists(self):
        """Retrieve all available threat lists"""
        response = self.service.threatLists().list().execute()
        return response['threatLists']

    def get_threats_update(self, client_state):
        """Fetch hash prefixes update for given threat list.

        client_state is a dict which looks like {(threatType, platformType, threatEntryType): clientState}

        :returns: the ``listUpdateResponses`` entry of the API response
        """
        request_body = {
            "client": {
                "clientId": self.client_id,
                "clientVersion": self.client_version,
            },
            "listUpdateRequests": [],
        }
        for (threat_type, platform_type, threat_entry_type), current_state in client_state.items():
            request_body['listUpdateRequests'].append(
                {
                    "threatType": threat_type,
                    "platformType": platform_type,
                    "threatEntryType": threat_entry_type,
                    "state": current_state,
                    "constraints": {
                        "supportedCompressions": ["RAW"]
                    },
                }
            )
        response = self.service.threatListUpdates().fetch(body=request_body).execute()
        # Remember the server-requested back-off so fair_use_delay() can honour it.
        self.set_wait_duration(response.get('minimumWaitDuration'))
        return response['listUpdateResponses']

    def get_full_hashes(self, prefixes, client_state):
        """Find full hashes matching hash prefixes.

        client_state is a dict which looks like {(threatType, platformType, threatEntryType): clientState}

        :param prefixes: iterable of hash prefixes as bytes
        :returns: the API response as returned by the service
        """
        request_body = {
            "client": {
                "clientId": self.client_id,
                "clientVersion": self.client_version,
            },
            "clientStates": [],
            "threatInfo": {
                "threatTypes": [],
                "platformTypes": [],
                "threatEntryTypes": [],
                "threatEntries": [],
            },
        }
        threat_info = request_body['threatInfo']
        for prefix in prefixes:
            threat_info['threatEntries'].append({"hash": b64encode(prefix).decode()})
        for ((threatType, platformType, threatEntryType), clientState) in client_state.items():
            request_body['clientStates'].append(clientState)
            # Each type is listed once, in first-seen order.
            if threatType not in threat_info['threatTypes']:
                threat_info['threatTypes'].append(threatType)
            if platformType not in threat_info['platformTypes']:
                threat_info['platformTypes'].append(platformType)
            if threatEntryType not in threat_info['threatEntryTypes']:
                threat_info['threatEntryTypes'].append(threatEntryType)
        response = self.service.fullHashes().find(body=request_body).execute()
        # Remember the server-requested back-off so fair_use_delay() can honour it.
        self.set_wait_duration(response.get('minimumWaitDuration'))
        return response

class URL(object):
    """URL representation suitable for lookup"""

    __py3 = (sys.version_info > (3, 0))

    def __init__(self, url):
        """Constructor.

        :param url: can be either of str or bytes type.
        """
        if self.__py3:
            # On Python 3 the URL is kept as bytes internally.
            if isinstance(url, bytes):
                self.url = bytes(url)
            else:
                self.url = url.encode()
        else:
            self.url = str(url)

    # NOTE(review): exposed as a property for consistency with `canonical`,
    # which this method reads as an attribute — confirm against callers.
    @property
    def hashes(self):
        """Hashes of all possible permutations of the URL in canonical form"""
        for url_variant in self.url_permutations(self.canonical):
            url_hash = self.digest(url_variant)
            yield url_hash

    @property
    def canonical(self):
        """Convert URL to its canonical form.

        :returns: canonical URL as ``scheme://host[:port]/path[?query]``
        """
        def full_unescape(u):
            # Unquote repeatedly until the value stops changing.
            while True:
                uu = urllib.unquote(u)
                if uu == u:
                    return uu
                u = uu

        def full_unescape_to_bytes(u):
            # Same as full_unescape() but yields bytes, so percent-encoded
            # sequences that are not valid unicode survive intact.
            while True:
                uu = urlparse.unquote_to_bytes(u)
                if uu == u:
                    return uu
                u = uu

        def quote(s):
            safe_chars = '!"$&\'()*+,-./:;<=>?@[\\]^_`{|}~'
            return urllib.quote(s, safe=safe_chars)

        url = self.url.strip()
        url = url.replace(b'\n', b'').replace(b'\r', b'').replace(b'\t', b'')
        url = url.split(b'#', 1)[0]  # drop the fragment
        if url.startswith(b'//'):
            url = b'http:' + url
        if len(url.split(b'://')) <= 1:
            url = b'http://' + url
        # at python3 work with bytes instead of string
        # as URL may contain invalid unicode characters
        if self.__py3 and isinstance(url, bytes):
            url = quote(full_unescape_to_bytes(url))
        else:
            url = quote(full_unescape(url))
        url_parts = urlparse.urlsplit(url)
        if not url_parts[0]:
            url = 'http://{}'.format(url)
            url_parts = urlparse.urlsplit(url)
        protocol = url_parts.scheme
        if self.__py3:
            host = full_unescape_to_bytes(url_parts.hostname)
            path = full_unescape_to_bytes(url_parts.path)
        else:
            host = full_unescape(url_parts.hostname)
            path = full_unescape(url_parts.path)
        query = url_parts.query
        # Keep an empty query only when the URL really had a '?' in it.
        if not query and '?' not in url:
            query = None
        if not path:
            path = b'/'
        has_trailing_slash = (path[-1:] == b'/')
        path = posixpath.normpath(path).replace(b'//', b'/')
        # normpath() strips the trailing slash; put it back if it was there.
        if has_trailing_slash and path[-1:] != b'/':
            path = path + b'/'
        port = url_parts.port
        host = host.strip(b'.')
        host = re.sub(br'\.+', b'.', host).lower()
        # Hosts given as a single decimal or hex integer are rewritten as a
        # dotted-quad IPv4 address.  This is best effort: on any failure the
        # host is left untouched.
        if host.isdigit():
            try:
                host = socket.inet_ntoa(struct.pack("!I", int(host)))
            except Exception:
                pass
        elif host.startswith(b'0x') and b'.' not in host:
            try:
                host = socket.inet_ntoa(struct.pack("!I", int(host, 16)))
            except Exception:
                pass
        quoted_path = quote(path)
        quoted_host = quote(host)
        if port is not None:
            quoted_host = '{}:{}'.format(quoted_host, port)
        canonical_url = '{}://{}{}'.format(protocol, quoted_host, quoted_path)
        if query is not None:
            canonical_url = '{}?{}'.format(canonical_url, query)
        return canonical_url

    @staticmethod
    def url_permutations(url):
        """Try all permutations of hostname and path which can be applied

        to blacklisted URLs

        :param url: canonical URL string
        :returns: generator of unique ``host + path`` strings
        """
        def url_host_permutations(host):
            # An IP address is only ever looked up as-is.
            if re.match(r'\d+\.\d+\.\d+\.\d+$', host):
                yield host
                return
            parts = host.split('.')
            depth = min(len(parts), 5)
            if depth > 4:
                yield host
            # Suffixes built from the last `depth` components, never
            # shorter than two components.
            for i in range(depth - 1):
                yield '.'.join(parts[i - depth:])

        def url_path_permutations(path):
            yield path
            query = None
            if '?' in path:
                path, query = path.split('?', 1)
            if query is not None:
                yield path
            # Leading directories, starting from the root, at most four.
            path_parts = path.split('/')[0:-1]
            curr_path = ''
            for i in range(min(4, len(path_parts))):
                curr_path = curr_path + path_parts[i] + '/'
                yield curr_path

        # NOTE: the split*() helpers are deprecated in Python 3.8+ but are
        # kept so the parsing behaviour stays exactly as before.
        protocol, address_str = urllib.splittype(url)
        host, path = urllib.splithost(address_str)
        user, host = urllib.splituser(host)
        host, port = urllib.splitport(host)
        host = host.strip('/')
        seen_permutations = set()
        for h in url_host_permutations(host):
            for p in url_path_permutations(path):
                u = '{}{}'.format(h, p)
                if u not in seen_permutations:
                    seen_permutations.add(u)
                    yield u

    @staticmethod
    def digest(url):
        """Hash the URL

        :returns: raw SHA-256 digest of the UTF-8 encoded string
        """
        return hashlib.sha256(url.encode('utf-8')).digest()

# Manual smoke test: when the module is run directly, build a client and
# request the available threat lists.
# NOTE(review): the API key below is hard-coded in the source; consider
# reading it from the environment or a command-line argument instead.
if __name__ == '__main__':
    from pprint import pprint
    c = SafeBrowsingApiClient('AIzaSyATpqLltciaMve61Wywb5yNDA8D8BvXEn4')
    # NOTE(review): `pprint` is imported but the result is not printed in the
    # code visible here — confirm whether a `pprint(r)` call is missing.
    r = c.get_threats_lists()