import argparse
import asyncio
import logging
import os
import ssl
import struct
import time

from aioice import stun
from aioice.ice import get_host_addresses
from aioice.turn import (
    UDP_TRANSPORT,
    TurnStreamMixin,
    is_channel_data,
    make_integrity_key,
)
from aioice.utils import random_string

logger = logging.getLogger("turn")

# Channel numbers accepted by ChannelBind requests: 0x4000 through 0x7FFF
# inclusive (range() excludes its upper bound, hence 0x8000).
CHANNEL_RANGE = range(0x4000, 0x8000)

ROOT = os.path.dirname(__file__)
CERT_FILE = os.path.join(ROOT, "turnserver.crt")
KEY_FILE = os.path.join(ROOT, "turnserver.key")


def create_self_signed_cert(name="localhost"):
    """
    Generate a self-signed certificate and write it next to this file.

    A 2048-bit RSA key pair is created, a certificate whose subject and
    issuer common name is `name` is signed with that key, and both are
    written in PEM format to CERT_FILE and KEY_FILE.
    """
    from OpenSSL import crypto

    # create key pair
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    # create self-signed certificate
    cert = crypto.X509()
    cert.get_subject().CN = name
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10 * 365 * 86400)
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(key)
    # SHA-1 signatures are rejected by current OpenSSL security levels
    cert.sign(key, "sha256")

    with open(CERT_FILE, "wb") as fp:
        fp.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(KEY_FILE, "wb") as fp:
        fp.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))


class Allocation(asyncio.DatagramProtocol):
    """
    Datagram endpoint which relays data between one client and its peers.

    `channel_to_peer` and `peer_to_channel` hold the channel bindings,
    `expiry` is the `time.time()` value after which the allocation is
    considered expired and `username` is the user who created it.
    """

    def __init__(self, client_address, client_protocol, expiry, username):
        self.channel_to_peer = {}
        self.peer_to_channel = {}

        self.client_address = client_address
        self.client_protocol = client_protocol
        self.expiry = expiry
        self.username = username

    def connection_made(self, transport):
        self.relayed_address = transport.get_extra_info("sockname")
        self.transport = transport

    def datagram_received(self, data, addr):
        """
        Relay data from peer to client.

        Data from a peer which has no channel bound is discarded.
        """
        channel = self.peer_to_channel.get(addr)
        if channel:
            # prepend the channel data header: channel number and length
            self.client_protocol._send(
                struct.pack("!HH", channel, len(data)) + data, self.client_address
            )


class TurnServerMixin:
    """
    Transport-independent handling of STUN / TURN traffic.

    Concrete protocols must provide `_send(data, addr)`.
    """

    def __init__(self, server):
        self.server = server

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        """
        Handle a packet from a client.

        Channel data is relayed to the bound peer, STUN requests are
        dispatched to the matching handler. Anything which cannot be
        parsed or authenticated is silently dropped.
        """
        # demultiplex channel data
        if len(data) >= 4 and is_channel_data(data):
            channel, length = struct.unpack("!HH", data[0:4])
            allocation = self.server.allocations.get((self, addr))

            if len(data) >= length + 4 and allocation:
                peer_address = allocation.channel_to_peer.get(channel)
                if peer_address:
                    payload = data[4 : 4 + length]
                    allocation.transport.sendto(payload, peer_address)

            return

        try:
            message = stun.parse_message(data)
        except ValueError:
            return

        logger.debug("< %s %s", addr, message)

        # Only requests are handled. This is network input, so it is
        # validated explicitly rather than with an assert, which would be
        # stripped under -O and raise on stray packets otherwise.
        if message.message_class != stun.Class.REQUEST:
            return

        if message.message_method == stun.Method.BINDING:
            response = self.handle_binding(message, addr)
            self.send_stun(response, addr)
            return

        if "USERNAME" not in message.attributes:
            response = self.error_response(message, 401, "Unauthorized")
            response.attributes["NONCE"] = random_string(16).encode("ascii")
            response.attributes["REALM"] = self.server.realm
            self.send_stun(response, addr)
            return

        # check credentials
        username = message.attributes["USERNAME"]
        try:
            password = self.server.users[username]
        except KeyError:
            # unknown user: drop, exactly like a failed integrity check
            logger.debug("Dropping request from unknown user %s", username)
            return
        integrity_key = make_integrity_key(username, self.server.realm, password)
        try:
            stun.parse_message(data, integrity_key=integrity_key)
        except ValueError:
            return

        if message.message_method == stun.Method.ALLOCATE:
            # creating the relay endpoint is asynchronous, the handler
            # sends its own response once done
            asyncio.ensure_future(self.handle_allocate(message, addr, integrity_key))
            return
        elif message.message_method == stun.Method.REFRESH:
            response = self.handle_refresh(message, addr)
        elif message.message_method == stun.Method.CHANNEL_BIND:
            response = self.handle_channel_bind(message, addr)
        else:
            response = self.error_response(
                message, 400, "Unsupported STUN request method"
            )

        response.add_message_integrity(integrity_key)
        response.add_fingerprint()
        self.send_stun(response, addr)

    async def handle_allocate(self, message, addr, integrity_key):
        """
        Handle an Allocate request and send the response to the client.

        On success a relay endpoint is opened on 127.0.0.1 and registered
        under the (protocol, client address) key.
        """
        key = (self, addr)
        if key in self.server.allocations:
            response = self.error_response(message, 437, "Allocation already exists")
        elif "REQUESTED-TRANSPORT" not in message.attributes:
            response = self.error_response(
                message, 400, "Missing REQUESTED-TRANSPORT attribute"
            )
        elif message.attributes["REQUESTED-TRANSPORT"] != UDP_TRANSPORT:
            response = self.error_response(
                message, 442, "Unsupported transport protocol"
            )
        else:
            lifetime = message.attributes.get("LIFETIME", self.server.default_lifetime)
            lifetime = min(lifetime, self.server.maximum_lifetime)

            # create allocation
            loop = asyncio.get_event_loop()
            _, allocation = await loop.create_datagram_endpoint(
                lambda: Allocation(
                    client_address=addr,
                    client_protocol=self,
                    expiry=time.time() + lifetime,
                    username=message.attributes["USERNAME"],
                ),
                local_addr=("127.0.0.1", 0),
            )
            self.server.allocations[key] = allocation

            logger.info("Allocation created %s", allocation.relayed_address)

            # build response
            response = stun.Message(
                message_method=message.message_method,
                message_class=stun.Class.RESPONSE,
                transaction_id=message.transaction_id,
            )
            response.attributes["LIFETIME"] = lifetime
            response.attributes["XOR-MAPPED-ADDRESS"] = addr
            response.attributes["XOR-RELAYED-ADDRESS"] = allocation.relayed_address

        # send response
        response.add_message_integrity(integrity_key)
        response.add_fingerprint()
        self.send_stun(response, addr)

    def handle_binding(self, message, addr):
        """
        Build the response to a Binding request, reporting the address
        the request was received from.
        """
        response = stun.Message(
            message_method=message.message_method,
            message_class=stun.Class.RESPONSE,
            transaction_id=message.transaction_id,
        )
        response.attributes["XOR-MAPPED-ADDRESS"] = addr
        return response

    def handle_channel_bind(self, message, addr):
        """
        Build the response to a ChannelBind request.

        The binding is registered on the client's allocation unless the
        channel or the peer is already bound to something else.
        """
        try:
            key = (self, addr)
            allocation = self.server.allocations[key]
        except KeyError:
            return self.error_response(message, 437, "Allocation does not exist")

        if message.attributes["USERNAME"] != allocation.username:
            return self.error_response(message, 441, "Wrong credentials")

        for attr in ["CHANNEL-NUMBER", "XOR-PEER-ADDRESS"]:
            if attr not in message.attributes:
                return self.error_response(message, 400, "Missing %s attribute" % attr)

        channel = message.attributes["CHANNEL-NUMBER"]
        peer_address = message.attributes["XOR-PEER-ADDRESS"]
        if channel not in CHANNEL_RANGE:
            return self.error_response(
                message, 400, "Channel number is outside valid range"
            )

        # re-binding the same channel / peer pair is allowed
        if allocation.channel_to_peer.get(channel) not in [None, peer_address]:
            return self.error_response(
                message, 400, "Channel is already bound to another peer"
            )
        if allocation.peer_to_channel.get(peer_address) not in [None, channel]:
            return self.error_response(
                message, 400, "Peer is already bound to another channel"
            )

        # register channel
        allocation.channel_to_peer[channel] = peer_address
        allocation.peer_to_channel[peer_address] = channel

        # build response
        response = stun.Message(
            message_method=message.message_method,
            message_class=stun.Class.RESPONSE,
            transaction_id=message.transaction_id,
        )
        return response

    def handle_refresh(self, message, addr):
        """
        Build the response to a Refresh request.

        A non-zero lifetime (capped to the server's maximum) extends the
        allocation, a zero lifetime deletes it and closes its relay
        transport.
        """
        try:
            key = (self, addr)
            allocation = self.server.allocations[key]
        except KeyError:
            return self.error_response(message, 437, "Allocation does not exist")

        if message.attributes["USERNAME"] != allocation.username:
            return self.error_response(message, 441, "Wrong credentials")

        if "LIFETIME" not in message.attributes:
            return self.error_response(message, 400, "Missing LIFETIME attribute")

        # refresh allocation
        lifetime = min(message.attributes["LIFETIME"], self.server.maximum_lifetime)
        if lifetime:
            logger.info("Allocation refreshed %s", allocation.relayed_address)
            allocation.expiry = time.time() + lifetime
        else:
            logger.info("Allocation deleted %s", allocation.relayed_address)
            del self.server.allocations[key]
            # release the relay socket, nothing references it anymore
            allocation.transport.close()

        # build response
        response = stun.Message(
            message_method=message.message_method,
            message_class=stun.Class.RESPONSE,
            transaction_id=message.transaction_id,
        )
        response.attributes["LIFETIME"] = lifetime
        return response

    def error_response(self, request, code, message):
        """
        Build an error response for the given request.
        """
        response = stun.Message(
            message_method=request.message_method,
            message_class=stun.Class.ERROR,
            transaction_id=request.transaction_id,
        )
        response.attributes["ERROR-CODE"] = (code, message)
        return response

    def send_stun(self, message, addr):
        """
        Serialize a STUN message and send it to the given address.
        """
        logger.debug("> %s %s", addr, message)
        self._send(bytes(message), addr)


class TurnServerTcpProtocol(TurnServerMixin, TurnStreamMixin, asyncio.Protocol):
    """
    STUN / TURN over a stream transport; `addr` is unused when sending.
    """

    def _send(self, data, addr):
        self.transport.write(data)


class TurnServerUdpProtocol(TurnServerMixin, asyncio.DatagramProtocol):
    """
    STUN / TURN over a datagram transport.
    """

    def _send(self, data, addr):
        self.transport.sendto(data, addr)


class TurnServer:
    """
    STUN / TURN server.
    """

    def __init__(self, realm="test", users=None):
        self.allocations = {}
        self.default_lifetime = 600
        self.maximum_lifetime = 3600
        self.realm = realm
        # username -> password; a fresh dict per instance avoids sharing
        # a mutable default between servers
        self.users = {} if users is None else users
        self._expire_handle = None

    async def close(self):
        """
        Stop the expiry loop, release all allocations and stop listening.
        """
        # stop expiry loop
        if self._expire_handle is not None:
            self._expire_handle.cancel()
            self._expire_handle = None

        # release relay sockets of remaining allocations
        for allocation in list(self.allocations.values()):
            allocation.transport.close()
        self.allocations.clear()

        self.tcp_server.close()
        self.tls_server.close()
        self.udp_server.transport.close()
        await self.tcp_server.wait_closed()
        await self.tls_server.wait_closed()

    async def listen(self, port=0, tls_port=0):
        """
        Start listening for TCP and UDP on `port` and for TLS on
        `tls_port`, using the first IPv4 host address, then start the
        allocation expiry loop.

        The TLS listener loads its certificate from CERT_FILE / KEY_FILE.
        """
        loop = asyncio.get_event_loop()
        hostaddr = get_host_addresses(use_ipv4=True, use_ipv6=False)[0]

        # listen for TCP
        self.tcp_server = await loop.create_server(
            lambda: TurnServerTcpProtocol(server=self), host=hostaddr, port=port
        )
        self.tcp_address = self.tcp_server.sockets[0].getsockname()
        logger.info("Listening for TCP on %s", self.tcp_address)

        # listen for UDP
        transport, self.udp_server = await loop.create_datagram_endpoint(
            lambda: TurnServerUdpProtocol(server=self), local_addr=(hostaddr, port)
        )
        self.udp_address = transport.get_extra_info("sockname")
        logger.info("Listening for UDP on %s", self.udp_address)

        # listen for TLS
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(CERT_FILE, KEY_FILE)
        self.tls_server = await loop.create_server(
            lambda: TurnServerTcpProtocol(server=self),
            host=hostaddr,
            port=tls_port,
            ssl=ssl_context,
        )
        self.tls_address = self.tls_server.sockets[0].getsockname()
        logger.info("Listening for TLS on %s", self.tls_address)

        # start expiry loop
        self._expire_handle = asyncio.ensure_future(self._expire_allocations())

    async def _expire_allocations(self):
        """
        Once per second, delete allocations whose expiry time has passed.
        """
        while True:
            now = time.time()
            # iterate over a snapshot: entries are deleted while looping
            for key, allocation in list(self.allocations.items()):
                if allocation.expiry < now:
                    logger.info("Allocation expired %s", allocation.relayed_address)
                    del self.allocations[key]
                    allocation.transport.close()
            await asyncio.sleep(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="STUN / TURN server")
    parser.add_argument("--verbose", "-v", action="count")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    srv = TurnServer(realm="test", users={"foo": "bar"})

    loop = asyncio.get_event_loop()
    loop.run_until_complete(srv.listen(port=3478, tls_port=5349))
    loop.run_forever()