#!/usr/bin/env python
"""
Implementation of the minusconf protocol. See http://code.google.com/p/minusconf/ for details.
Apache License 2.0, see the LICENSE file for details.

Most users will want a (Thread)Advertiser to advertise their server's location and a Seeker to find out the locations of servers and report them to a client program.

"""

import struct
import socket
import threading
import time

_PORT = 6376
_ADDRESS_4 = '239.45.99.98'
_ADDRESS_6 = 'ff08:0:0:6d69:6e75:7363:6f6e:6600'
_ADDRESSES = [_ADDRESS_4]
if socket.has_ipv6:
    _ADDRESSES.append(_ADDRESS_6)
_CHARSET = 'UTF-8'
VERSION='1.0'

# Compatibility functions
# _compat_bytes turns a native string literal made of '\xNN' escapes into the
# interpreter's byte string type.
try:
    if bytes != str: # Python 3+
        # 'charmap' encodes each character of the literal as a single byte
        _compat_bytes = lambda bytestr: bytes(bytestr, 'charmap')
    else: # 2.6+
        _compat_bytes = str
except NameError: # <2.6
    _compat_bytes = str

# _compat_str is the interpreter's text string type
try:
    _compat_str = unicode
except NameError: # Python 3+
    _compat_str = str

# Every packet starts with _MAGIC, followed by a one-byte opcode and the payload
_MAGIC = _compat_bytes('\xad\xc3\xe6\xe7')
_OPCODE_QUERY = _compat_bytes('\x01')
_OPCODE_ADVERTISEMENT = _compat_bytes('\x65')
_OPCODE_ERROR = _compat_bytes('\x6f')
# Every string in a payload is followed by this byte
_STRING_TERMINATOR = _compat_bytes('\x00')

# Multicast TTL handed to _multicast_configure_sender; None leaves the socket's default untouched
_TTL = None
_MAX_PACKET_SIZE = 2048 # Biggest packet size this implementation will accept
_SEEKER_TIMEOUT = 2.0 # Timeout for seeks in s


class MinusconfError(Exception):
    """ Raised for malformed or unsupported minusconf packets.

    The message is available as the msg attribute and, like for any other exception, through str() and args.
    """

    def __init__(self, msg=''):
        # Hand the message to Exception so that str(error) and tracebacks show it
        super(MinusconfError, self).__init__(msg)
        self.msg = msg

    def send(self, sock, to):
        """ Report this error to the peer at address to as an error packet sent over sock. """
        _send_packet(sock, to, _OPCODE_ERROR, _encode_string(self.msg))


class _ImmutableStruct(object):
    """ Helper structure for immutable objects """

    def __setattr__(self, *args):
        raise TypeError("This structure is immutable")
    __delattr__ = __setattr__

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            super(_ImmutableStruct, self).__setattr__(k, v)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        return self.__dict__ != other.__dict__

    def __lt__(self, other):
        return self.__dict__ < other.__dict__

    def __le__(self, other):
        return self.__dict__ <= other.__dict__

    def __gt__(self, other):
        return self.__dict__ > other.__dict__

    def __ge__(self, other):
        return self.__dict__ >= other.__dict__

    def __hash__(self):
        return hash(sum((hash(i) for i in self.__dict__.items())))


class _MinusconfImmutableStruct(_ImmutableStruct):
    """ Immutable structure whose field values have all been passed through _check_val. """

    def __init__(self, **kwargs):
        # Validate every field before any of them is stored
        for field_value in kwargs.values():
            _check_val(field_value)

        super(_MinusconfImmutableStruct, self).__init__(**kwargs)


class Service(_MinusconfImmutableStruct):
    """ Immutable description of one service: its type, port, optional name and optional location.

    The port is stored as a string.
    """

    def __init__(self, stype, port, sname='', location=''):
        port_str = _compat_str(port)
        super(Service, self).__init__(stype=stype, port=port_str, sname=sname, location=location)

    def matches_query(self, stype, sname):
        """ Returns whether this service satisfies a query for the given service type and name. """
        if not _string_match(stype, self.stype):
            return False
        return _string_match(sname, self.sname)

    def __str__(self):
        pieces = [self.stype, ' service at ']
        if self.sname != '':
            pieces.extend([self.sname, ' '])
        pieces.extend([self.location, ':', self.port])

        return ''.join(pieces)

    def __repr__(self):
        fields = (self.stype, self.port, self.sname, self.location)
        return 'Service(' + ', '.join([repr(field) for field in fields]) + ')'


class ServiceAt(_MinusconfImmutableStruct):
    """ A service as reported by an advertiser: the service fields plus the advertiser's name and address. """

    def __init__(self, aname, stype, sname, location, port, addr):
        fields = dict(aname=aname, stype=stype, sname=sname, location=location, port=port, addr=addr)
        super(ServiceAt, self).__init__(**fields)

    def matches_query_at(self, aname, stype, sname):
        """ Returns whether this result satisfies a query for the given advertiser name, service type and name. """
        if not _string_match(stype, self.stype):
            return False
        if not _string_match(sname, self.sname):
            return False
        return _string_match(aname, self.aname)

    @property
    def effective_location(self):
        """ The advertised location, or the advertiser's address if no location was advertised. """
        if self.location != "":
            return self.location
        return self.addr

    def __str__(self):
        name_part = ''
        if self.sname != '':
            name_part = self.sname + ' '

        return ''.join([
            self.stype, ' service at ',
            name_part,
            self.location, ':', self.port,
            ' (advertiser "', self.aname, '" at ', self.addr, ')',
        ])

    def __repr__(self):
        fields = (self.aname, self.stype, self.sname, self.location, self.port, self.addr)
        return 'ServiceAt(' + ', '.join([repr(field) for field in fields]) + ')'


class Advertiser(object):
    """ Generic implementation of a -conf advertiser. You will probably want to use one of the subclasses.

    services is the list of Service objects to advertise; the list is used as is, so later changes to it are picked up.
    aname is the advertiser name and defaults to the host name.
    If ignore_unavailable is set, unsupported addresses (typically IPv6) are silently ignored
    """

    def __init__(self, services=None, aname=None, ignore_unavailable=True):
        super(Advertiser, self).__init__()

        # A fresh list per instance; a shared default list would leak services between advertisers
        self.services = services if services is not None else []
        self.aname = aname if aname is not None else socket.gethostname()
        self.port = _PORT
        self.addresses = _ADDRESSES
        self.ignore_unavailable = ignore_unavailable

    def _set_aname(self, aname):
        _check_val(aname)
        self._aname = aname
    aname = property(fget=lambda self:self._aname, fset=_set_aname)

    def run(self):
        """ Set up the socket and answer queries forever. """
        self._init_advertiser()

        while True:
            rawdata,sender = self._sock.recvfrom(_MAX_PACKET_SIZE)
            self._handle_packet(rawdata, sender)

    def _init_advertiser(self):
        """ Creates the socket, binds it to self.port and joins the multicast groups in self.addresses.

        The ready socket is stored in self._sock. If anything fails, the socket is closed before the error is passed on.
        """
        sock = _find_sock()
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, struct.pack('@I', 1))
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, struct.pack('@I', 1))
            if sock.family == socket.AF_INET6:
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, struct.pack('@I', 1))
            sock.bind(('', self.port))

            addrs = _resolve_addrs(self.addresses, None, self.ignore_unavailable, (sock.family,))

            for fam,to,orig_fam,orig_addr in addrs:
                try:
                    _multicast_join_group(sock, orig_fam, orig_addr)
                except socket.error:
                    if not self.ignore_unavailable:
                        raise
        except Exception:
            # Do not leak the half-configured socket
            sock.close()
            raise

        self._sock = sock

    def _handle_packet(self, rawdata, sender):
        """ Dispatches one received datagram. Invalid packets are dropped silently. """
        try:
            opcode, data = _parse_packet(rawdata)

            if opcode == _OPCODE_QUERY:
                self._handle_query(sender, data)
            elif opcode == _OPCODE_ERROR:
                pass # Explicitly prevent bouncing errors
            elif opcode is None:
                raise MinusconfError('Minusconf magic missing. See http://code.google.com/p/minusconf/source/browse/trunk/protocol.txt for details.')
            else:
                raise MinusconfError('Invalid or unsupported opcode ' + str(struct.unpack('!B', opcode)[0]))
        except MinusconfError:
            # For verbose error handling, report the error to the sender with
            # its send(self._sock, sender) method instead of dropping it.
            pass

    def services_matching(self, stype, sname):
        """ Returns the list of advertised services that match the queried service type and name. """
        return [svc for svc in self.services if svc.matches_query(stype, sname)]

    def _handle_query(self, sender, qrydata):
        """ Answers a query with one advertisement packet per matching service.

        Raises MinusconfError if the query payload is malformed.
        """
        qaname,p = _decode_string(qrydata, 0)
        qstype,p = _decode_string(qrydata, p)
        qsname,p = _decode_string(qrydata, p)

        if _string_match(qaname, self.aname):
            for svc in self.services_matching(qstype, qsname):
                rply = (
                    _encode_string(self.aname) +
                    _encode_string(svc.stype) +
                    _encode_string(svc.sname) +
                    _encode_string(svc.location) +
                    _encode_string(svc.port)
                    )

                _send_packet(self._sock, sender, _OPCODE_ADVERTISEMENT, rply)


class ConcurrentAdvertiser(Advertiser):
    """ Base class for advertisers that run in the background.

    Relies on a start() method and a _cav_started event supplied by the subclass.
    """
    # Subclasses must set _cav_started to an event

    def start_blocking(self):
        """ Start the advertiser in the background, but wait until it is ready """

        self._cav_started.clear()
        self.start()
        self._cav_started.wait()

    def _init_advertiser(self):
        """ Initializes the advertiser and signals _cav_started afterwards. """
        try:
            super(ConcurrentAdvertiser, self)._init_advertiser()
        finally:
            # Signal even if initialization failed, so that waiters are not blocked forever
            self._cav_started.set()

    def wait_until_ready(self, timeout=None):
        """ Wait until initialization has finished, or until timeout (passed to the event's wait) has elapsed. """
        self._cav_started.wait(timeout)

    def stop(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def stop_blocking(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()


class ThreadAdvertiser(ConcurrentAdvertiser, threading.Thread):
    """ Advertiser that answers queries in a background thread.

    See Advertiser for services, aname and ignore_unavailable. daemon is passed to setDaemon.
    """

    def __init__(self, services=None, aname=None, ignore_unavailable=True, daemon=True):
        if services is None:
            # A fresh list per instance instead of one default list shared by all advertisers
            services = []
        ConcurrentAdvertiser.__init__(self, services, aname, ignore_unavailable)
        threading.Thread.__init__(self)

        self.setDaemon(daemon)

        self._cav_started = self._createEvent()
        self._ta_should_stop = self._createEvent()

    def run(self):
        """ Thread body: answer queries until stop() has been called and another packet arrives. """
        self._ta_should_stop.clear()

        self._init_advertiser()

        try:
            while True:
                rawdata,sender = self._sock.recvfrom(_MAX_PACKET_SIZE)
                if self._ta_should_stop.is_set():
                    break
                self._handle_packet(rawdata, sender)
        finally:
            # Release the socket once the thread is done with it
            self._sock.close()

    def stop(self):
        """ Ask the thread to terminate. It only notices the request when the next packet arrives. """
        self._ta_should_stop.set()

    def stop_blocking(self):
        """ Stop the service. Does not wait: the thread will be there, but will terminate upon the next message. """
        self.stop()

    @staticmethod
    def _createEvent():
        """ Returns a threading.Event that is guaranteed to have an is_set method. """
        res = threading.Event()

        if not hasattr(res, 'is_set'): # Python<2.6
            res.is_set = res.isSet

        return res

try:
    import multiprocessing

    class MultiprocessingAdvertiser(ConcurrentAdvertiser, multiprocessing.Process):
        """
        Advertiser that answers queries in a separate process.

        The given services are copied into a multiprocessing manager list, which is available as self.services.

        multiprocessing is only available for Python 2.6+.
        See http://code.google.com/p/python-multiprocessing/ for a backport.
        """
        def __init__(self, services=None, aname=None, ignore_unavailable=True, daemon=True):
            if services is None:
                # No mutable default argument, consistent with the other advertisers
                services = []
            ConcurrentAdvertiser.__init__(self, services, aname, ignore_unavailable)
            multiprocessing.Process.__init__(self)

            self.daemon = daemon
            self._cav_started = multiprocessing.Event()

            self._mpa_manager = multiprocessing.Manager()
            self.services = self._mpa_manager.list(services)

        def stop(self):
            """ Terminate the advertiser process. """
            self.terminate()

        def stop_blocking(self):
            """ Terminate the advertiser process and wait until it has exited. """
            self.stop()
            self.join()
except ImportError:
    pass


class Seeker(threading.Thread):
    """ Sends a query to the given addresses and collects the advertisements received in reply.

    Empty values for stype, aname and sname match anything.
    timeout is the time in seconds to wait for replies; None waits forever.
    find_callback is called with (this_seeker,found_service_at)
    error_callback is called with (this seeker, sender, error message)
    If ignore_senderrors is set, addresses that cannot be resolved or sent to are skipped.

    The distinct results of a run are collected in the set self.results.
    """
    def __init__(self, stype='', aname='', sname='', timeout=_SEEKER_TIMEOUT, port=_PORT, addresses=_ADDRESSES, find_callback=None, error_callback=None, daemonized=True, ignore_senderrors=True):
        super(Seeker, self).__init__()

        self.timeout = timeout
        self.port = port
        self.addresses = addresses
        self.find_callback = find_callback
        self.error_callback = error_callback
        self.setDaemon(daemonized)
        self.ignore_senderrors = ignore_senderrors
        self.reset(stype, aname, sname)

    def reset(self, stype='', aname='', sname=''):
        """ Set the service type, advertiser name and service name to look for. """
        self.stype = stype
        self.aname = aname
        self.sname = sname

    def _set_stype(self, stype):
        _check_val(stype)
        self._stype = stype
    stype = property(fget=lambda self:self._stype, fset=_set_stype)

    def _set_aname(self, aname):
        _check_val(aname)
        self._aname = aname
    aname = property(fget=lambda self:self._aname, fset=_set_aname)

    def _set_sname(self, sname):
        _check_val(sname)
        self._sname = sname
    sname = property(fget=lambda self:self._sname, fset=_set_sname)

    def run(self):
        """ Send the queries and, if at least one could be sent, read replies until the timeout expires. """
        self._init_seeker()

        try:
            if self._send_queries() > 0:
                self._read_replies()
        finally:
            # Every run creates a new socket, so release this one
            self._sock.close()

    def run_forever(self):
        """ Like run, but without a timeout. """
        self.timeout = None
        self.run()

    def _init_seeker(self):
        """ Clears the results and creates the socket for this run. """
        self.results = set()

        self._sock = _find_sock()
        _multicast_configure_sender(self._sock, _TTL)

    def _send_queries(self):
        """ Sends queries to multiple addresses. Returns the number of successful queries. """

        res = 0

        addrs = _resolve_addrs(self.addresses, self.port, self.ignore_senderrors, [self._sock.family])
        for addr in addrs:
            try:
                self._send_query(addr[1])
                res += 1
            except Exception:
                # Not a bare except: KeyboardInterrupt and SystemExit must always get through
                if not self.ignore_senderrors:
                    raise

        return res

    def _send_query(self, to):
        """ Sends one query packet to the socket address to. """
        binqry = _encode_string(self.aname)
        binqry += _encode_string(self.stype)
        binqry += _encode_string(self.sname)

        _send_packet(self._sock, to, _OPCODE_QUERY, binqry)

    def _read_replies(self):
        """ Receives and handles packets until self.timeout seconds have passed (forever if it is None). """
        if self.timeout is None:
            self._sock.settimeout(None)
            deadline = None
        else:
            deadline = time.time() + self.timeout

        while True:
            if deadline is not None:
                remaining = deadline - time.time()
                # A timeout of exactly 0 would switch the socket to non-blocking mode
                if remaining <= 0:
                    break

                self._sock.settimeout(remaining)

            try:
                rawdata,sender = self._sock.recvfrom(_MAX_PACKET_SIZE)
            except socket.timeout:
                break

            self._handle_packet(rawdata, sender)

    def _handle_packet(self, rawdata, sender):
        """ Dispatches one received datagram. Invalid packets and unknown opcodes are ignored. """
        try:
            opcode,data = _parse_packet(rawdata)

            if opcode == _OPCODE_ADVERTISEMENT:
                self._handle_advertisement(data, sender)
            elif opcode == _OPCODE_ERROR:
                try:
                    error_str = _decode_string(data, 0)[0]
                except Exception:
                    error_str = '[Error when trying to read error message ' + repr(data) + ']'

                if self.error_callback is not None:
                    self.error_callback(self, sender, error_str)
            else: # Invalid opcode
                pass
        except MinusconfError: # Invalid packet
            pass

    def _handle_advertisement(self, bindata, sender):
        """ Decodes an advertisement payload and records it if it matches the current query.

        Raises MinusconfError if the payload is malformed.
        """
        aname,p = _decode_string(bindata, 0)
        stype,p = _decode_string(bindata, p)
        sname,p = _decode_string(bindata, p)
        location,p = _decode_string(bindata, p)
        port,p = _decode_string(bindata, p)

        if stype == '': # servicetype must be non-empty
            return

        svca = ServiceAt(aname, stype, sname, location, port, sender[0])
        if svca.matches_query_at(self.aname, self.stype, self.sname):
            self._found_result(svca)

    def _found_result(self, result):
        """ Stores a result and calls find_callback, unless the same result was already found. """
        if result not in self.results:
            self.results.add(result)
            if self.find_callback is not None:
                self.find_callback(self, result)

def _send_packet(sock, to, opcode, data):
    """ Sends a packet consisting of the magic, opcode and data over sock to the address to. """
    packet = _MAGIC + opcode + data
    sock.sendto(packet, 0, to)

def _parse_packet(rawdata):
    """ Splits a raw datagram into a tuple (opcode, minusconf-data).

    Both elements are None if the datagram is too short or does not start with the minusconf magic.
    """
    magic_len = len(_MAGIC)

    # At least the magic and the opcode byte must be present
    if len(rawdata) <= magic_len or rawdata[:magic_len] != _MAGIC:
        return (None, None)

    payload_start = magic_len + 1
    return (rawdata[magic_len:payload_start], rawdata[payload_start:])

def _check_val(val):
    """ Checks whether a minusconf value contains any NUL bytes. """
    try:
        if val.find('\x00') >= 0:
            raise ValueError(repr(val) + ' contains a NUL byte')
    except AttributeError: # Not a string or compatible
        pass

def _encode_string(val):
    """ Returns val encoded in the protocol charset, followed by the string terminator. """
    encoded = val.encode(_CHARSET)
    return encoded + _STRING_TERMINATOR

def _decode_string(buf, pos):
    """ Decodes a string in the buffer buf, starting at position pos.
    Returns a tuple of the read string and the next byte to read.

    pos is expected to be non-negative.
    Raises MinusconfError if no terminator follows pos or if the bytes are not valid in the protocol charset.
    """
    # One search call instead of comparing byte by byte against a freshly built terminator
    end = buf.find(_STRING_TERMINATOR, pos)
    if end < 0:
        raise MinusconfError("Premature end of string (Forgot trailing \\0?), buf=" + repr(buf))

    raw = buf[pos:end]
    try:
        return (raw.decode(_CHARSET), end + 1)
    except UnicodeDecodeError:
        # The UnicodeDecodeError itself carries more detail, should it ever be needed
        raise MinusconfError('Not a valid ' + _CHARSET + ' string: ' + repr(raw))

def _string_match(query, value):
    return query == "" or query == value

def _multicast_configure_sender(sock, ttl=None):
    if ttl != None:
        ttl_bin = struct.pack('@I', ttl)

        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl_bin)
        if sock.family == socket.AF_INET6:
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, ttl_bin)

def _multicast_join_group(sock, family, addr):
    group_bin = _inet_pton(family, addr)
    if family == socket.AF_INET: # IPv4
        mreq = group_bin + struct.pack('=I', socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
    elif family == socket.AF_INET6: # IPv6
        mreq = group_bin + struct.pack('@I', 0)
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)
    else:
        raise ValueError('Unsupported protocol family ' + family)

def _resolve_addrs(straddrs, port, ignore_unavailable=False, protocols=[socket.AF_INET, socket.AF_INET6]):
    """ Returns a list of tuples of (family, to, original_addr_family, original_addr).

    If ignore_unavailable is set, addresses for unavailable protocols are ignored.
    protocols determines the protocol family indices supported by the socket in use. """

    res = []
    for sa in straddrs:
        try:
            ais = socket.getaddrinfo(sa, port)
            for ai in ais:
                if ai[0] in protocols:
                    res.append((ai[0], ai[4], ai[0], ai[4][0]))
                    break
            else:
                # Try to convert from IPv4 to IPv6
                ai = ais[0]
                if ai[0] == socket.AF_INET and socket.AF_INET6 in protocols:
                    to = socket.getaddrinfo('::ffff:' + ai[4][0], port, socket.AF_INET6)[0][4]
                    res.append((socket.AF_INET6, to, ai[0], ai[4][0]))
        except socket.gaierror:
            if not ignore_unavailable:
                raise

    return res

def _find_sock():
    """ Create a UDP socket """
    if socket.has_ipv6:
        try:
            return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        except socket.gaierror:
            pass # Platform lied about IPv6 support
    return socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

def _main():
    """ CLI interface: runs an advertiser or a seeker, depending on the subcommand in sys.argv[1]. """
    import sys

    if len(sys.argv) < 2:
        _usage('Expected at least one parameter!')

    sc = sys.argv[1]
    options = sys.argv[2:]
    if sc in ('a', 'advertise'):
        if len(options) > 5 or len(options) < 2:
            _usage()

        stype,port = options[:2]
        advertisername = options[2] if len(options) > 2 else None
        sname = options[3] if len(options) > 3 else ''
        slocation = options[4] if len(options) > 4 else ''

        service = Service(stype, port, sname, slocation)
        advertiser = Advertiser([service], advertisername)
        advertiser.run()
    elif sc in ('s', 'seek'):
        # seek takes at most servicetype, advertisername and servicename
        if len(options) > 3:
            _usage()

        # Same order as printed by _usage and as expected by Seeker
        stype = options[0] if len(options) > 0 else ''
        aname = options[1] if len(options) > 1 else ''
        sname = options[2] if len(options) > 2 else ''

        se = Seeker(stype, aname, sname, find_callback=_print_result, error_callback=_print_error)
        se.run()
    else:
        # Name the offending subcommand, not the program
        _usage('Unknown subcommand "' + sc + '"')

def _print_result(seeker, svca):
    print ("Found " + str(svca))

def _print_error(seeker, opposite, error_str):
    import sys
    sys.stderr.write("Error from " + str(opposite) + ": " + error_str + "\n")

def _usage(note=None, and_exit=True):
    import sys

    if note != None:
        print("Error: " + note + "\n")

    print("Usage: " + sys.argv[0] + " subcommand options...")
    print("\ta[dvertise] servicetype port [advertisername [servicename [location]]]")
    print("\ts[eek]      [servicetype [advertisername [servicename]]]")
    print('Use "" for default/any value.')
    print("Examples:")
    print("\t" + sys.argv[0] + " advertise http 80 fastmachine Apache")
    print("\t" + sys.argv[0] + ' seek http "" Apache')

    if and_exit:
        sys.exit(0)

def _compat_inet_pton(family, addr):
    """ socket.inet_pton for platforms that don't have it.

    Returns the packed binary form of the textual IPv4 or IPv6 address addr.
    Raises ValueError if the address is malformed or the family is neither socket.AF_INET nor socket.AF_INET6.
    """

    if family == socket.AF_INET:
        # inet_aton accepts some strange forms, so we use our own
        res = _compat_bytes('')
        parts = addr.split('.')
        if len(parts) != 4:
            raise ValueError('Expected 4 dot-separated numbers')

        for part in parts:
            intval = int(part, 10)
            if intval < 0 or intval > 0xff:
                raise ValueError("Invalid integer value in IPv4 address: " + str(intval))

            res = res + struct.pack('!B', intval)

        return res
    elif family == socket.AF_INET6:
        wordcount = 8
        res = _compat_bytes('')

        # IPv4 embedded?
        dotpos = addr.find('.')
        if dotpos >= 0:
            v4start = addr.rfind(':', 0, dotpos)
            if v4start == -1:
                raise ValueError("Missing colons in an IPv6 address")
            # The embedded IPv4 address stands for the last two of the eight words
            wordcount = 6
            res = socket.inet_aton(addr[v4start+1:])
            addr = addr[:v4start] + '!' # We leave a marker that the address is not finished

        # Compact version?
        compact_pos = addr.find('::')
        if compact_pos >= 0:
            if compact_pos == 0:
                addr = '0' + addr
                compact_pos += 1
            if compact_pos == len(addr)-len('::'):
                addr = addr + '0'

            # Expand the '::' to as many zero words as are missing
            addr = (addr[:compact_pos] + ':' +
                ('0:' * (wordcount - (addr.count(':') - '::'.count(':')) - 2))
                + addr[compact_pos + len('::'):])

        # Remove the marker we left
        if addr.endswith('!'):
            addr = addr[:-len('!')]

        words = addr.split(':')
        if len(words) != wordcount:
            raise ValueError('Invalid number of IPv6 hextets, expected ' + str(wordcount) + ', got ' + str(len(words)))
        # Back to front, since each word is prepended to the result
        for w in reversed(words):
            # 0x and negative is not valid here, but accepted by int(,16)
            if 'x' in w or '-' in w:
                raise ValueError("Invalid character in IPv6 address")

            intval = int(w, 16)
            if intval > 0xffff:
                raise ValueError("IPv6 address componenent too big")
            res = struct.pack('!H', intval) + res

        return res

    else:
        # str() because family is a number
        raise ValueError("Unknown protocol family " + str(family))

# Cover for socket_pton inavailability on some systems (non-IPv6 or Windows)
try:
    import ipaddr

    if hasattr(ipaddr.IPv4, 'packed'):
        def _inet_pton(family, addr):
            if family == socket.AF_INET:
                return ipaddr.IPv4(addr).packed
            elif family == socket.AF_INET6:
                return ipaddr.IPv6(addr).packed
            else:
                raise ValueError("Unknown protocol family " + family)
except:
    pass

if not '_inet_pton' in dir():
    if hasattr(socket, 'inet_pton'):
        _inet_pton = socket.inet_pton
    else:
        _inet_pton = _compat_inet_pton

# Run the command line interface when executed as a script
if __name__ == '__main__':
    _main()