#!/usr/bin/env python
# coding:utf-8

__version__ = '1.0'

import sys
import os
import sysconfig

# Make bundled third-party packages importable: append two directories that
# sit next to this file under 'packages.egg' -- 'noarch' and one named after
# the first '-'-separated component of sysconfig.get_platform().
sys.path += [os.path.abspath(os.path.join(__file__, '../packages.egg/%s' % x)) for x in ('noarch', sysconfig.get_platform().split('-')[0])]

import gevent
import gevent.server
import gevent.timeout
import gevent.monkey
# Apply gevent's monkey patches (subprocess included) before the remaining
# standard-library modules below are imported.
gevent.monkey.patch_all(subprocess=True)

import re
import time
import logging
import heapq
import socket
import select
import struct
import errno
import thread
import dnslib
import Queue
import pygeoip


# Matcher for loopback/private IPv4 addresses (127.x.x.x, 10.x.x.x,
# 192.168.x.x, 172.16-31.x.x), optionally preceded by an IPv6 prefix ending
# in '0:5efe:'.  It is a bound ``match`` method, so only the start of the
# string is anchored; it returns a match object or None.
is_local_addr = re.compile(
    r'(?i)(?:[0-9a-f:]+0:5efe:)?(?:127(?:\.\d+){3}|10(?:\.\d+){3}|192\.168(?:\.\d+){2}|172\.(?:1[6-9]|2\d|3[01])(?:\.\d+){2})'
).match


def get_dnsserver_list():
    """Return the DNS server addresses configured on this machine.

    On Windows (``os.name == 'nt'``) the list is read through
    ``dnsapi.DnsQueryConfig``; elsewhere the ``nameserver`` lines of
    ``/etc/resolv.conf`` are used.  When neither source is available a
    warning is logged and an empty list is returned.
    """
    if os.name == 'nt':
        import ctypes
        import ctypes.wintypes
        import struct
        import socket
        DNS_CONFIG_DNS_SERVER_LIST = 6
        raw = ctypes.create_string_buffer(2048)
        raw_size = ctypes.wintypes.DWORD(len(raw))
        ctypes.windll.dnsapi.DnsQueryConfig(DNS_CONFIG_DNS_SERVER_LIST, 0, None, None, ctypes.byref(raw), ctypes.byref(raw_size))
        # The buffer is read as a 4-byte count followed by that many packed
        # 4-byte IPv4 addresses.
        count, = struct.unpack('I', raw[0:4])
        servers = []
        for offset in xrange(4, count*4+4, 4):
            servers.append(socket.inet_ntoa(raw[offset:offset+4]))
        return servers

    if not os.path.isfile('/etc/resolv.conf'):
        logging.warning("get_dnsserver_list failed: unsupport platform '%s-%s'", sys.platform, os.name)
        return []

    with open('/etc/resolv.conf', 'rb') as fp:
        content = fp.read()
    return re.findall(r'(?m)^nameserver\s+(\S+)', content)


def parse_hostport(host, default_port=80):
    """Split ``'address#port'`` into an ``(address, port)`` tuple.

    The port is introduced by ``#`` (not ``:``) and must be all digits at the
    end of the string; without it ``default_port`` is used.  Surrounding
    square brackets are stripped from the address part.
    """
    matched = re.match(r'(.+)[#](\d+)$', host)
    if matched is None:
        return host.strip('[]'), default_port
    address, port = matched.groups()
    return address.strip('[]'), int(port)


class ExpireCache(object):
    """A dictionary-like cache whose entries expire.

    Every key is stored with an absolute expiry time (whole seconds).  A
    min-heap of ``(expire_time, key)`` tuples gives cheap access to the entry
    that expires first; it is used both to purge expired entries and to evict
    the soonest-expiring entries when more than ``max_size`` keys are held.
    """

    def __init__(self, max_size=1024):
        """Create an empty cache holding at most ``max_size`` entries."""
        self.__maxsize = max_size
        self.__values = {}
        self.__expire_times = {}
        self.__expire_heap = []

    def _remove_heap_entry(self, entry):
        """Remove ``entry`` from the expiry heap, keeping it a valid heap.

        Raises ValueError if ``entry`` is not in the heap.
        """
        heap = self.__expire_heap
        pos = heap.index(entry)
        # A plain ``del heap[pos]`` shifts every later element and breaks the
        # heap invariant.  Move the tail item into the hole instead, then
        # rebuild; heapify is O(n), the same order as the index() scan above.
        heap[pos] = heap[-1]
        heap.pop()
        heapq.heapify(heap)

    def size(self):
        """Return the number of stored entries (expired ones included until purged)."""
        return len(self.__values)

    def clear(self):
        """Drop every entry."""
        self.__values.clear()
        self.__expire_times.clear()
        del self.__expire_heap[:]

    def exists(self, key):
        """Return True if ``key`` is stored, whether or not it has expired."""
        return key in self.__values

    def set(self, key, value, expire):
        """Store ``value`` under ``key`` for ``expire`` seconds from now.

        An existing entry for ``key`` is replaced.  Expired and surplus
        entries are purged afterwards.
        """
        try:
            old_et = self.__expire_times[key]
        except KeyError:
            pass
        else:
            self._remove_heap_entry((old_et, key))
        et = int(time.time() + expire)
        self.__expire_times[key] = et
        heapq.heappush(self.__expire_heap, (et, key))
        self.__values[key] = value
        self.cleanup()

    def get(self, key):
        """Return the value stored under ``key``.

        Raises KeyError if the key is missing or has expired.
        """
        et = self.__expire_times[key]
        if et < time.time():
            self.cleanup()
            raise KeyError(key)
        return self.__values[key]

    def delete(self, key):
        """Remove ``key`` from the cache; raises KeyError if it is missing."""
        et = self.__expire_times.pop(key)
        self._remove_heap_entry((et, key))
        del self.__values[key]

    def cleanup(self):
        """Purge expired entries, then evict the soonest-expiring while over capacity."""
        t = int(time.time())
        eh = self.__expire_heap
        ets = self.__expire_times
        v = self.__values
        size = self.__maxsize
        heappop = heapq.heappop
        # Explicit grouping: never pop from an empty heap, even if the value
        # dict were somehow larger than the heap.
        while eh and (eh[0][0] <= t or len(v) > size):
            _, key = heappop(eh)
            del v[key], ets[key]


def dnslib_resolve_over_udp(query, dnsservers, timeout, **kwargs):
    """
    Resolve a DNS query over UDP by asking every server in `dnsservers`.

    The query is sent to all IPv4 servers through one socket and to all IPv6
    servers through another; replies are then read until `timeout` seconds
    after the call started.  A reply is returned as soon as it either
    contains addresses none of which is in the blacklist, or carries a
    non-zero rcode with no addresses and comes from a trusted server.

    query: host name string or dnslib.DNSRecord (TypeError otherwise).
    dnsservers: iterable of server addresses; an entry containing ':' is
        treated as IPv6.  A '#port' suffix is accepted, default port 53.
    timeout: overall deadline in seconds.
    kwargs: 'blacklist' -- addresses that invalidate a reply;
            'turstservers' (sic) -- servers whose error replies are accepted.

    Returns the accepted dnslib.DNSRecord; raises socket.gaierror(11004) if
    no acceptable reply was received.

    Background reading:
    http://gfwrev.blogspot.com/2009/11/gfwdns.html
    http://zh.wikipedia.org/wiki/%E5%9F%9F%E5%90%8D%E6%9C%8D%E5%8A%A1%E5%99%A8%E7%BC%93%E5%AD%98%E6%B1%A1%E6%9F%93
    http://support.microsoft.com/kb/241352
    https://gist.github.com/klzgrad/f124065c0616022b65e5
    """
    if not isinstance(query, (basestring, dnslib.DNSRecord)):
        raise TypeError('query argument requires string/DNSRecord')
    blacklist = kwargs.get('blacklist', ())
    turstservers = kwargs.get('turstservers', ())
    # Split the servers by address family; one UDP socket is shared per family.
    dns_v4_servers = [x for x in dnsservers if ':' not in x]
    dns_v6_servers = [x for x in dnsservers if ':' in x]
    sock_v4 = sock_v6 = None
    socks = []
    if dns_v4_servers:
        sock_v4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        socks.append(sock_v4)
    if dns_v6_servers:
        sock_v6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        socks.append(sock_v6)
    # Single deadline shared by all four attempts below.
    # NOTE(review): once this deadline has passed, later attempts re-send the
    # query but skip the receive loop entirely -- confirm this is intended.
    timeout_at = time.time() + timeout
    try:
        for _ in xrange(4):
            try:
                for dnsserver in dns_v4_servers:
                    # A string query is converted to a DNSRecord on first use;
                    # `query` stays a DNSRecord for every later server/attempt.
                    if isinstance(query, basestring):
                        if dnsserver in ('8.8.8.8', '8.8.4.4'):
                            # Rewrites the letter case of the name (the final
                            # .title() call title-cases each label).
                            # NOTE(review): presumably a mixed-case trick for
                            # Google DNS -- confirm.
                            query = '.'.join(x[:-1] + x[-1].upper() for x in query.split('.')).title()
                        query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query))
                    query_data = query.pack()
                    if query.q.qtype == 1 and dnsserver in ('8.8.8.8', '8.8.4.4'):
                        # For A queries (qtype 1) to Google DNS: replace the
                        # byte preceding the last four bytes of the packet with
                        # the two bytes '\xc0\x04'.
                        # NOTE(review): presumably the compression-pointer
                        # trick described in the gist linked above -- confirm.
                        query_data = query_data[:-5] + '\xc0\x04' + query_data[-4:]
                    sock_v4.sendto(query_data, parse_hostport(dnsserver, 53))
                for dnsserver in dns_v6_servers:
                    # Only reached with a string query when there were no IPv4
                    # servers; in that case an AAAA question is built.
                    if isinstance(query, basestring):
                        query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query, qtype=dnslib.QTYPE.AAAA))
                    query_data = query.pack()
                    sock_v6.sendto(query_data, parse_hostport(dnsserver, 53))
                # Poll both sockets in 0.1s slices until the deadline.
                while time.time() < timeout_at:
                    ins, _, _ = select.select(socks, [], [], 0.1)
                    for sock in ins:
                        # At most 512 bytes of each datagram are read.
                        reply_data, reply_address = sock.recvfrom(512)
                        reply_server = reply_address[0]
                        record = dnslib.DNSRecord.parse(reply_data)
                        # Collect rdata of record types 1 (A), 28 (AAAA) and 255.
                        iplist = [str(x.rdata) for x in record.rr if x.rtype in (1, 28, 255)]
                        if any(x in blacklist for x in iplist):
                            # Blacklisted address: discard and keep waiting.
                            logging.warning('query=%r dnsservers=%r record bad iplist=%r', query, dnsservers, iplist)
                        elif record.header.rcode and not iplist and reply_server in turstservers:
                            # Error reply without addresses is accepted only
                            # when it comes from a trusted server.
                            logging.info('query=%r trust reply_server=%r record rcode=%s', query, reply_server, record.header.rcode)
                            return record
                        elif iplist:
                            logging.debug('query=%r reply_server=%r record iplist=%s', query, reply_server, iplist)
                            return record
                        else:
                            # No addresses and not a trusted error: ignore.
                            logging.debug('query=%r reply_server=%r record null iplist=%s', query, reply_server, iplist)
                            continue
            except socket.error as e:
                # Socket failures are logged and the attempt is repeated.
                logging.warning('handle dns query=%s socket: %r', query, e)
        raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
    finally:
        # Always release the sockets, on return as well as on error.
        for sock in socks:
            sock.close()


def dnslib_resolve_over_tcp(query, dnsservers, timeout, **kwargs):
    """Resolve a DNS query over TCP, asking every server in parallel.

    One worker thread is started per server.  Each worker posts either the
    parsed reply or the exception it hit to a shared queue; the first reply
    that is not an exception is returned.

    query: host name string or dnslib.DNSRecord (TypeError otherwise).  For a
        string, the question type is AAAA for servers containing ':' and A
        for the others.
    dnsservers: sequence of server addresses, optional '#port' suffix,
        default port 53.
    timeout: seconds allowed for each socket operation and for each wait on
        the result queue; a falsy value disables both limits.
    kwargs: 'blacklist' -- addresses that invalidate a reply.

    Returns the first acceptable dnslib.DNSRecord.
    Raises socket.gaierror(11004) when every server failed or timed out.
    """
    if not isinstance(query, (basestring, dnslib.DNSRecord)):
        raise TypeError('query argument requires string/DNSRecord')
    blacklist = kwargs.get('blacklist', ())

    def do_resolve(query, dnsserver, timeout, queobj):
        """Query one server and always post a record or an exception to queobj."""
        sock = None
        rfile = None
        try:
            if isinstance(query, basestring):
                qtype = dnslib.QTYPE.AAAA if ':' in dnsserver else dnslib.QTYPE.A
                query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query, qtype=qtype))
            query_data = query.pack()
            sock_family = socket.AF_INET6 if ':' in dnsserver else socket.AF_INET
            sock = socket.socket(sock_family)
            sock.settimeout(timeout or None)
            sock.connect(parse_hostport(dnsserver, 53))
            # DNS over TCP prefixes each message with an unsigned 16-bit
            # length (RFC 1035 section 4.2.2), hence '>H' rather than '>h'.
            # sendall() guarantees the whole request is written.
            sock.sendall(struct.pack('>H', len(query_data)) + query_data)
            rfile = sock.makefile('rb', 1024)
            reply_data_length = rfile.read(2)
            if len(reply_data_length) < 2:
                raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
            reply_data = rfile.read(struct.unpack('>H', reply_data_length)[0])
            record = dnslib.DNSRecord.parse(reply_data)
            iplist = [str(x.rdata) for x in record.rr if x.rtype in (1, 28, 255)]
            if any(x in blacklist for x in iplist):
                logging.debug('query=%r dnsserver=%r record bad iplist=%r', query, dnsserver, iplist)
                raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
            else:
                logging.debug('query=%r dnsserver=%r record iplist=%s', query, dnsserver, iplist)
                queobj.put(record)
        except socket.error as e:
            logging.debug('query=%r dnsserver=%r failed %r', query, dnsserver, e)
            queobj.put(e)
        except Exception as e:
            # Thread boundary: report unexpected failures (e.g. a malformed
            # reply) too, so the consumer is never left short of a result.
            logging.warning('query=%r dnsserver=%r failed %r', query, dnsserver, e)
            queobj.put(e)
        finally:
            if rfile:
                rfile.close()
            if sock:
                sock.close()

    queobj = Queue.Queue()
    for dnsserver in dnsservers:
        thread.start_new_thread(do_resolve, (query, dnsserver, timeout, queobj))
    result = None
    for _ in dnsservers:
        try:
            # Queue.get's first positional argument is `block`, not `timeout`;
            # pass both explicitly so the wait is actually bounded.
            result = queobj.get(True, timeout or None)
        except Queue.Empty:
            raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
        if result and not isinstance(result, Exception):
            return result
    if dnsservers:
        # Every server answered, none acceptably; log the last failure.
        logging.warning('dnslib_resolve_over_tcp %r with %s return %r', query, dnsservers, result)
    raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))


class DNSServer(gevent.server.DatagramServer):
    """DNS proxy based on gevent/dnslib.

    Incoming UDP queries are answered from an expiring cache when possible,
    otherwise forwarded to the configured upstream servers (over TCP for the
    configured name suffixes, over UDP for everything else).
    """

    def __init__(self, *args, **kwargs):
        """Create the server.

        Keyword arguments consumed here (everything else is passed on to
        gevent.server.DatagramServer):
            dns_blacklist: iterable of addresses that invalidate a reply (required).
            dns_servers: iterable of upstream server addresses (required).
            dns_tcpover: name suffixes resolved over TCP (default: none).
            dns_timeout: upstream timeout in seconds (default: 2).
        """
        dns_blacklist = kwargs.pop('dns_blacklist')
        dns_servers = kwargs.pop('dns_servers')
        dns_tcpover = kwargs.pop('dns_tcpover', [])
        dns_timeout = kwargs.pop('dns_timeout', 2)
        # Name the class explicitly: super(self.__class__, self) would recurse
        # forever as soon as DNSServer is subclassed.
        super(DNSServer, self).__init__(*args, **kwargs)
        self.dns_servers = list(dns_servers)
        self.dns_tcpover = tuple(dns_tcpover)
        self.dns_intranet_servers = [x for x in self.dns_servers if is_local_addr(x)]
        self.dns_blacklist = set(dns_blacklist)
        self.dns_timeout = int(dns_timeout)
        self.dns_cache = ExpireCache(max_size=65536)
        self.dns_trust_servers = set(['8.8.8.8', '8.8.4.4', '2001:4860:4860::8888', '2001:4860:4860::8844'])
        # With the first GeoIP database found, additionally trust every IPv4
        # upstream server whose country is not reported as 'China'.
        for dirname in ('.', '/usr/share/GeoIP/', '/usr/local/share/GeoIP/'):
            filename = os.path.join(dirname, 'GeoIP.dat')
            if os.path.isfile(filename):
                geoip = pygeoip.GeoIP(filename)
                for dnsserver in self.dns_servers:
                    if ':' not in dnsserver and geoip.country_name_by_addr(parse_hostport(dnsserver, 53)[0]) not in ('China',):
                        self.dns_trust_servers.add(dnsserver)
                break

    @staticmethod
    def _error_reply(request):
        """Build an rcode=3 response carrying the id of `request`."""
        # qr=1 marks the packet as a response, matching the reverse-lookup
        # reply built in get_reply_record.
        return dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, qr=1, rcode=3))

    def do_read(self):
        """Read one datagram, ignoring aborted/reset/broken-pipe socket errors.

        Returns None when such an error was swallowed.
        """
        try:
            return gevent.server.DatagramServer.do_read(self)
        except socket.error as e:
            code = e.args[0] if e.args else None
            if code not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE):
                raise

    def get_reply_record(self, data):
        """Return the dnslib.DNSRecord that answers the raw query `data`.

        The returned record may carry the id of an earlier request when it
        comes from the cache; `handle` overwrites the id before sending.
        """
        request = dnslib.DNSRecord.parse(data)
        qname = str(request.q.qname).lower()
        qtype = request.q.qtype
        dnsservers = self.dns_servers
        if qname.endswith('.in-addr.arpa'):
            # Reverse lookups are answered locally with the address encoded
            # in the name itself.
            ipaddr = '.'.join(reversed(qname[:-13].split('.')))
            record = dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, qr=1,aa=1,ra=1), a=dnslib.RR(qname, rdata=dnslib.A(ipaddr)))
            return record
        if 'USERDNSDOMAIN' in os.environ:
            user_dnsdomain = '.' + os.environ['USERDNSDOMAIN'].lower()
            if qname.endswith(user_dnsdomain):
                qname = qname[:-len(user_dnsdomain)]
                if '.' not in qname:
                    # A bare host name inside the user's DNS domain can only
                    # be resolved by the intranet servers.
                    if not self.dns_intranet_servers:
                        logging.warning('qname=%r is a plain hostname, need intranet dns server!!!', qname)
                        return self._error_reply(request)
                    qname += user_dnsdomain
                    dnsservers = self.dns_intranet_servers
        try:
            return self.dns_cache.get((qname, qtype))
        except KeyError:
            pass
        try:
            dns_resolve = dnslib_resolve_over_tcp if qname.endswith(self.dns_tcpover) else dnslib_resolve_over_udp
            kwargs = {'blacklist': self.dns_blacklist, 'turstservers': self.dns_trust_servers}
            record = dns_resolve(request, dnsservers, self.dns_timeout, **kwargs)
            # Cache for twice the largest TTL in the reply (600 is the
            # nominal TTL for a reply without records).
            ttl = max(x.ttl for x in record.rr) if record.rr else 600
            self.dns_cache.set((qname, qtype), record, ttl * 2)
            return record
        except socket.gaierror as e:
            logging.warning('resolve %r failed: %r', qname, e)
            return self._error_reply(request)

    def handle(self, data, address):
        """Answer one datagram received from `address`."""
        logging.debug('receive from %r data=%r', address, data)
        record = self.get_reply_record(data)
        # The first two bytes are the query id: reuse the request's id so
        # that cached records match the query being answered.
        return self.sendto(data[:2] + record.pack()[2:], address)


def test():
    """Configure logging and run a DNSServer on UDP port 53 with built-in settings."""
    logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(asctime)s %(message)s', datefmt='[%b %d %H:%M:%S]')
    upstream_servers = '8.8.8.8|8.8.4.4|168.95.1.1|168.95.192.1|223.5.5.5|223.6.6.6|114.114.114.114|114.114.115.115'.split('|')
    bad_addresses = '1.1.1.1|255.255.255.255|74.125.127.102|74.125.155.102|74.125.39.102|74.125.39.113|209.85.229.138|4.36.66.178|8.7.198.45|37.61.54.158|46.82.174.68|59.24.3.173|64.33.88.161|64.33.99.47|64.66.163.251|65.104.202.252|65.160.219.113|66.45.252.237|72.14.205.104|72.14.205.99|78.16.49.15|93.46.8.89|128.121.126.139|159.106.121.75|169.132.13.103|192.67.198.6|202.106.1.2|202.181.7.85|203.161.230.171|203.98.7.65|207.12.88.98|208.56.31.43|209.145.54.50|209.220.30.174|209.36.73.33|211.94.66.147|213.169.251.35|216.221.188.182|216.234.179.13|243.185.187.3|243.185.187.39|23.89.5.60|37.208.111.120|49.2.123.56|54.76.135.1|77.4.7.92|118.5.49.6|188.5.4.96|189.163.17.5|197.4.4.12|249.129.46.48|253.157.14.165|183.207.229.|183.207.232.'.split('|')
    # Names ending in these suffixes are resolved over TCP.
    tcp_suffixes = ['.youtube.com', '.googlevideo.com']
    logging.info('serving at port 53...')
    server = DNSServer(('', 53), dns_servers=upstream_servers, dns_blacklist=bad_addresses, dns_tcpover=tcp_suffixes)
    server.serve_forever()


# Start the demo DNS proxy when this file is executed as a script.
if __name__ == '__main__':
    test()