"""Exposes the full ``requests`` HTTP library API, while adding an extra
``family`` parameter to all HTTP request operations that may be used to restrict
the address family used when resolving a domain-name to an IP address.
"""
import socket
import urllib.parse

import requests
import requests.adapters
import urllib3
import urllib3.connection
import urllib3.exceptions
import urllib3.poolmanager
import urllib3.util.connection

# Address family number -> short name used as URL scheme suffix ("http+ip4")
AF2NAME = dict((
	(int(socket.AF_INET), "ip4"),
	(int(socket.AF_INET6), "ip6"),
))
# Reverse lookup: scheme suffix -> address family number
NAME2AF = dict(zip(AF2NAME.values(), AF2NAME.keys()))


# This function is copied from urllib3/util/connection.py (that in turn copied
# it from socket.py in the Python 2.7 standard library test suite) and accepts
# an extra `family` parameter that specifies the allowed address families for
# name resolution.
#
# The rest of this file exists only to ensure that this `family` parameter is
# exposed all the way up to the `Session` interface of `requests`, storing it as
# part of the URL scheme while traversing most of the layers.
def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
                      source_address=None, socket_options=None,
                      family=socket.AF_UNSPEC):
	"""Connect to *address* and return the connected socket.

	Every result of `socket.getaddrinfo` for the host is tried in order; the
	first socket that connects successfully is returned.

	:param address: ``(host, port)`` pair; a host wrapped in square brackets
	    (IPv6 literal notation) has the brackets removed before resolution
	:param timeout: Socket timeout applied before connecting, unless it is the
	    global-default sentinel
	:param source_address: Optional local address to bind the socket to
	:param socket_options: Optional iterable of argument tuples, each passed
	    to ``sock.setsockopt(*opt)`` before connecting
	:param family: Address family to restrict name resolution to; a falsy
	    value or ``socket.AF_UNSPEC`` selects whatever
	    ``urllib3.util.connection.allowed_gai_family()`` reports
	:raises OSError: The error of the last failed attempt, or a generic error
	    if name resolution produced no candidates at all
	"""
	host, port = address
	if host.startswith('['):
		host = host.strip('[]')
	err = None

	if not family or family == socket.AF_UNSPEC:
		family = urllib3.util.connection.allowed_gai_family()

	for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
		af, socktype, proto, _canonname, sa = res
		sock = None
		try:
			sock = socket.socket(af, socktype, proto)

			# If provided, set socket level options before connecting.
			if socket_options is not None:
				for opt in socket_options:
					sock.setsockopt(*opt)

			if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
				sock.settimeout(timeout)
			if source_address:
				sock.bind(source_address)
			sock.connect(sa)
			return sock
		except OSError as e:
			# Remember the failure and move on to the next candidate address
			err = e
			if sock is not None:
				sock.close()
				sock = None

	if err is not None:
		try:
			raise err
		finally:
			# Explicitly break the reference cycle between the exception, its
			# traceback and this frame's `err` local, so that everything is
			# released without having to wait for the cycle collector
			err = None

	raise OSError("getaddrinfo returns an empty list")


# Override the `urllib3` low-level Connection objects that do the actual work
# of speaking HTTP
def _kw_scheme_to_family(kw, base_scheme):
	"""Remove ``"scheme"`` from *kw* and translate it to an address family.

	A scheme of the form ``"<base_scheme>+<name>"`` yields the family that
	`NAME2AF` maps ``<name>`` to; anything else (missing scheme, non-string
	scheme, other base scheme, unknown name) yields ``socket.AF_UNSPEC``.
	"""
	scheme = kw.pop("scheme", None)
	if not isinstance(scheme, str):
		return socket.AF_UNSPEC

	# Split at the last "+": everything before it must be the plain scheme
	prefix, plus, af_name = scheme.rpartition("+")
	if plus and prefix == base_scheme:
		return NAME2AF.get(af_name, socket.AF_UNSPEC)
	return socket.AF_UNSPEC


class ConnectionOverrideMixin:
	"""Mixin whose `_new_conn` opens the socket through this module's
	`create_connection`, passing along the `family` attribute of the instance.
	"""

	def _new_conn(self):
		"""Open a new socket to the target host and return it.

		:raises urllib3.exceptions.ConnectTimeoutError: Connecting timed out
		:raises urllib3.exceptions.NewConnectionError: Any other `OSError`
		"""
		connect_kw = dict(family=self.family)

		# Only forward the optional settings when they are actually set
		if self.source_address:
			connect_kw['source_address'] = self.source_address
		if self.socket_options:
			connect_kw['socket_options'] = self.socket_options

		try:
			# Prefer `_dns_host` for name resolution when the instance has it
			target_host = getattr(self, "_dns_host", self.host)
			return create_connection(
				(target_host, self.port), self.timeout, **connect_kw)
		except socket.timeout:
			message = ("Connection to %s timed out. (connect timeout=%s)" %
			           (self.host, self.timeout))
			raise urllib3.exceptions.ConnectTimeoutError(self, message)
		except OSError as e:
			message = "Failed to establish a new connection: %s" % e
			raise urllib3.exceptions.NewConnectionError(self, message)


class HTTPConnection(ConnectionOverrideMixin, urllib3.connection.HTTPConnection):
	"""Plain HTTP connection that remembers the address family encoded in the
	``scheme`` keyword argument it was created with.
	"""

	def __init__(self, *args, **kwargs):
		# The `scheme` entry must be removed from the keyword arguments
		# before they are handed on to the base class
		family = _kw_scheme_to_family(kwargs, "http")
		self.family = family
		super().__init__(*args, **kwargs)


class HTTPSConnection(ConnectionOverrideMixin, urllib3.connection.HTTPSConnection):
	"""HTTPS connection that remembers the address family encoded in the
	``scheme`` keyword argument it was created with.
	"""

	def __init__(self, *args, **kwargs):
		# The `scheme` entry must be removed from the keyword arguments
		# before they are handed on to the base class
		family = _kw_scheme_to_family(kwargs, "https")
		self.family = family
		super().__init__(*args, **kwargs)


# Override the higher-level `urllib3` ConnectionPool objects that instantiate
# one or more Connection objects and dispatch work between them
class HTTPConnectionPool(urllib3.HTTPConnectionPool):
	# Point the pool at this module's family-aware `HTTPConnection` class
	ConnectionCls = HTTPConnection


class HTTPSConnectionPool(urllib3.HTTPSConnectionPool):
	# Point the pool at this module's family-aware `HTTPSConnection` class
	ConnectionCls = HTTPSConnection


# Override the highest-level `urllib3` PoolManager to also properly support the
# address family extended scheme values in URLs and pass these scheme values on
# to the individual ConnectionPool objects
class PoolManager(urllib3.PoolManager):
	"""`urllib3.PoolManager` variant that knows the address-family extended
	schemes (``http+ip4``, ``https+ip6``, ...) and keeps the ``scheme`` value
	in the keyword arguments handed to the connection pools it creates.
	"""

	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)

		# Register our pool classes for the plain schemes and for every
		# "<scheme>+<family name>" variant of them; each variant reuses the
		# pool-key function of the plain scheme it extends
		pool_classes = self.pool_classes_by_scheme = {}
		for base_scheme, pool_cls in (("http", HTTPConnectionPool), ("https", HTTPSConnectionPool)):
			pool_classes[base_scheme] = pool_cls
			for af_name in AF2NAME.values():
				extended = "{0}+{1}".format(base_scheme, af_name)
				pool_classes[extended] = pool_cls
				self.key_fn_by_scheme[extended] = self.key_fn_by_scheme[base_scheme]

	# The two methods below are only needed so that our custom `scheme` values
	# travel down to the `*ConnectionPool`s and from there to the actual
	# `*Connection`s as a parameter
	def _new_pool(self, scheme, host, port, request_context=None):
		"""Create a connection pool for *scheme*, *host* and *port*.

		Copied from `urllib3`, except that the `scheme` entry is *not*
		suppressed from the request context passed to the pool.
		"""
		pool_cls = self.pool_classes_by_scheme[scheme]
		if request_context is None:
			request_context = self.connection_pool_kw.copy()

		# Host and port are passed positionally below
		request_context.pop('host', None)
		request_context.pop('port', None)

		# Plain HTTP pools (with or without family suffix) take no SSL options
		if scheme.partition("+")[0] == "http":
			for ssl_kw in urllib3.poolmanager.SSL_KEYWORDS:
				request_context.pop(ssl_kw, None)

		return pool_cls(host, port, **request_context)

	def connection_from_pool_key(self, pool_key, request_context=None):
		"""Return the pool stored under *pool_key*, creating it if needed.

		Copied from `urllib3` so that pool creation keeps going through the
		`_new_pool` method of this class.
		"""
		with self.pools.lock:
			pool = self.pools.get(pool_key)
			if not pool:
				pool = self._new_pool(
					request_context['scheme'],
					request_context['host'],
					request_context['port'],
					request_context=request_context,
				)
				self.pools[pool_key] = pool
		return pool


# Override the lower-level `requests` adapter that invokes the `urllib3`
# PoolManager objects
class HTTPAdapter(requests.adapters.HTTPAdapter):
	"""Transport adapter that creates this module's `PoolManager`."""

	def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
		"""Create the `PoolManager` instance used by this adapter.

		:param connections: Passed to the pool manager as ``num_pools``
		:param maxsize: Passed to the pool manager as ``maxsize``
		:param block: Passed to the pool manager as ``block``
		:param pool_kwargs: Extra keyword arguments passed on unchanged
		"""
		# save these values for pickling (copied from `requests`)
		self._pool_connections = connections
		self._pool_maxsize = maxsize
		self._pool_block = block

		# `strict` is intentionally no longer passed here: urllib3 1.x drops
		# the value on Python 3, while urllib3 2.x removed the parameter and
		# fails with a `TypeError` once a connection is instantiated with it
		self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize,
		                               block=block, **pool_kwargs)


# Override the highest-level `requests` Session object to accept the `family`
# parameter for any request and encode its value as part of the URL scheme
# when passing it down to the adapter
class Session(requests.Session):
	"""`requests.Session` that accepts an extra ``family`` keyword argument.

	The session-wide default is stored in the `family` attribute (initially
	``socket.AF_UNSPEC``) and may be overridden per request.
	"""

	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
		self.family = socket.AF_UNSPEC

		# Additionally to mounting our variant of the usual HTTP and HTTPS
		# adapter, also mount it for some variants of the default schemes that
		# are limited to some specific address family only
		adapter = HTTPAdapter()
		for scheme in ("http", "https"):
			self.mount("{0}://".format(scheme), adapter)
			for name in AF2NAME.values():
				self.mount("{0}+{1}://".format(scheme, name), adapter)

	def request(self, method, url, *args, **kwargs):
		"""Send a request, optionally restricted to one address family.

		:param family: Optional keyword argument; defaults to `self.family`.
		    A falsy value (``None``, ``0``) or ``socket.AF_UNSPEC`` leaves the
		    URL untouched, any other value is encoded into the URL scheme
		    (for example ``http`` becomes ``http+ip4``)
		:raises KeyError: *family* has no entry in `AF2NAME`
		"""
		family = kwargs.pop("family", self.family)
		# Treat falsy values such as `None` as "unrestricted", the same way
		# `create_connection` does, instead of failing on `int(None)`
		if family and family != socket.AF_UNSPEC:
			# Inject provided address family value as extension to scheme
			url = urllib.parse.urlparse(url)
			url = url._replace(scheme="{0}+{1}".format(url.scheme, AF2NAME[int(family)]))
			url = url.geturl()
		return super().request(method, url, *args, **kwargs)


# Lower-case alias of the `Session` class: calling `session()` creates a new
# `Session` instance
session = Session


# Import other `requests` stuff to make the top-level API of this more compatible
from requests import (
	__title__, __description__, __url__, __version__, __build__, __author__,
	__author_email__, __license__, __copyright__, __cake__,
	
	exceptions, utils, packages, codes,
	Request, Response, PreparedRequest,
	RequestException, Timeout, URLRequired, TooManyRedirects, HTTPError,
	ConnectionError, FileModeWarning, ConnectTimeout, ReadTimeout
)


# Re-implement the top-level “session-less” API
def request(method, url, **kwargs):
	"""Send a single request through a temporary `Session`.

	The session is used as a context manager and so is closed again when the
	call finishes. All keyword arguments are handed to `Session.request`.
	"""
	with Session() as one_shot:
		response = one_shot.request(method=method, url=url, **kwargs)
		return response


def get(url, params=None, **kwargs):
	"""Send a ``get`` request; `allow_redirects` defaults to ``True``."""
	if 'allow_redirects' not in kwargs:
		kwargs['allow_redirects'] = True
	return request('get', url, params=params, **kwargs)


def options(url, **kwargs):
	"""Send an ``options`` request; `allow_redirects` defaults to ``True``."""
	if 'allow_redirects' not in kwargs:
		kwargs['allow_redirects'] = True
	return request('options', url, **kwargs)


def head(url, **kwargs):
	"""Send a ``head`` request; `allow_redirects` defaults to ``False``."""
	if 'allow_redirects' not in kwargs:
		kwargs['allow_redirects'] = False
	return request('head', url, **kwargs)


def post(url, data=None, json=None, **kwargs):
	"""Send a ``post`` request with the given `data` and `json` values."""
	kwargs.update(data=data, json=json)
	return request('post', url, **kwargs)


def put(url, data=None, **kwargs):
	"""Send a ``put`` request with the given `data` value."""
	kwargs.update(data=data)
	return request('put', url, **kwargs)


def patch(url, data=None, **kwargs):
	"""Send a ``patch`` request with the given `data` value."""
	kwargs.update(data=data)
	return request('patch', url, **kwargs)


def delete(url, **kwargs):
	"""Send a ``delete`` request; all keyword arguments are passed on."""
	response = request('delete', url, **kwargs)
	return response