"""TLS helpers built on pyOpenSSL (certificates/keys) and the stdlib ssl module."""

import logging
import ssl
import tempfile

from OpenSSL import crypto

from nintendo.common import socket, scheduler

logger = logging.getLogger(__name__)

# Serialization formats accepted by the ``format`` parameters below.
TYPE_DER = 0
TYPE_PEM = 1

# Protocol selectors accepted by SSLContext / SSLClient / SSLServer.
VERSION_TLS = 0
VERSION_TLS11 = 1
VERSION_TLS12 = 2

# Maps the TYPE_* constants to pyOpenSSL file types.
TypeMap = {
    TYPE_DER: crypto.FILETYPE_ASN1,
    TYPE_PEM: crypto.FILETYPE_PEM,
}

# Maps the VERSION_* constants to ssl protocol constants.
VersionMap = {
    VERSION_TLS: ssl.PROTOCOL_TLS,
    VERSION_TLS11: ssl.PROTOCOL_TLSv1_1,
    VERSION_TLS12: ssl.PROTOCOL_TLSv1_2,
}


class SSLCertificate:
    """Thin wrapper around a pyOpenSSL ``crypto.X509`` object."""

    def __init__(self, obj):
        # The wrapped crypto.X509 instance.
        self.obj = obj

    def public_key(self):
        """Return the public numbers of the certificate's public key.

        The key is converted with ``to_cryptography_key()`` and its
        ``public_numbers()`` are returned.
        """
        pkey = self.obj.get_pubkey()
        # NOTE(review): the original local was named ``rsakey``; presumably
        # only RSA keys are expected here -- confirm against callers.
        rsakey = pkey.to_cryptography_key()
        return rsakey.public_numbers()

    def encode(self, format):
        """Serialize the certificate; ``format`` is TYPE_DER or TYPE_PEM."""
        return crypto.dump_certificate(TypeMap[format], self.obj)

    @staticmethod
    def load(filename, format):
        """Read ``filename`` as bytes and parse it as a certificate."""
        with open(filename, "rb") as f:
            data = f.read()
        return SSLCertificate.parse(data, format)

    @staticmethod
    def parse(data, format):
        """Parse certificate bytes; ``format`` is TYPE_DER or TYPE_PEM."""
        cert = crypto.load_certificate(TypeMap[format], data)
        return SSLCertificate(cert)

    @staticmethod
    def generate(key):
        """Create a certificate for ``key`` (an SSLPrivateKey), signed by it.

        The certificate uses common name ``*`` and a fixed validity window
        from 2000-01-01 to 2999-01-01.
        """
        cert = crypto.X509()
        cert.set_pubkey(key.obj)
        cert.set_notBefore(b"20000101000000Z")
        cert.set_notAfter(b"29990101000000Z")

        subject = cert.get_subject()
        subject.commonName = "*"

        # NOTE(review): SHA-1 is considered weak today; confirm whether
        # peers depend on it before switching to a stronger digest.
        cert.sign(key.obj, "sha1")
        return SSLCertificate(cert)


class SSLPrivateKey:
    """Thin wrapper around a pyOpenSSL ``crypto.PKey`` object."""

    def __init__(self, obj):
        # The wrapped crypto.PKey instance.
        self.obj = obj

    def encode(self, format):
        """Serialize the private key; ``format`` is TYPE_DER or TYPE_PEM."""
        return crypto.dump_privatekey(TypeMap[format], self.obj)

    @staticmethod
    def load(filename, format):
        """Read ``filename`` as bytes and parse it as a private key."""
        with open(filename, "rb") as f:
            data = f.read()
        return SSLPrivateKey.parse(data, format)

    @staticmethod
    def parse(data, format):
        """Parse private key bytes; ``format`` is TYPE_DER or TYPE_PEM."""
        pkey = crypto.load_privatekey(TypeMap[format], data)
        return SSLPrivateKey(pkey)

    @staticmethod
    def generate():
        """Generate a new 1024-bit RSA private key."""
        pkey = crypto.PKey()
        # NOTE(review): 1024-bit RSA is considered weak today; confirm
        # whether peers require this size before increasing it.
        pkey.generate_key(crypto.TYPE_RSA, 1024)
        return SSLPrivateKey(pkey)


class SSLContext:
    """Wrapper around ``ssl.SSLContext`` that accepts the wrapper types above."""

    def __init__(self, version):
        """Create a context; ``version`` is one of the VERSION_* constants."""
        self.context = ssl.SSLContext(VersionMap[version])

    def set_certificate(self, cert, key):
        """Load ``cert`` (SSLCertificate) and ``key`` (SSLPrivateKey).

        ``load_cert_chain`` only accepts file names, so both objects are
        written in PEM form to temporary files first.
        """
        # The ``with`` statement guarantees both temporary files (one of
        # which holds the private key) are closed even when encoding,
        # writing or load_cert_chain raises.
        # NOTE(review): the tempfile documentation states that a
        # NamedTemporaryFile cannot be reopened by name on Windows while it
        # is still open -- confirm whether Windows needs to be supported.
        with tempfile.NamedTemporaryFile() as certfile, \
                tempfile.NamedTemporaryFile() as keyfile:
            certfile.write(cert.encode(TYPE_PEM))
            keyfile.write(key.encode(TYPE_PEM))
            # Flush so the data is on disk before it is read back by name.
            certfile.flush()
            keyfile.flush()
            self.context.load_cert_chain(certfile.name, keyfile.name)

    def wrap(self, sock, host=None):
        """Wrap ``sock`` in TLS.

        Without ``host`` the socket is wrapped as a server-side socket;
        with ``host`` it is wrapped as a client-side socket and ``host`` is
        passed as ``server_hostname``.
        """
        if host is None:
            return self.context.wrap_socket(sock, True)
        return self.context.wrap_socket(sock, False, server_hostname=host)


class SSLClient:
    """TLS client layered on top of ``socket.TCPClient``."""

    def __init__(self, version=VERSION_TLS12, sock=None):
        """Create a client.

        ``sock`` is an existing client object to use; when it is falsy a new
        ``socket.TCPClient`` is created.
        """
        self.s = sock
        if not self.s:
            self.s = socket.TCPClient()
        self.context = SSLContext(version)

    def set_certificate(self, cert, key):
        """Set the certificate and private key on this client's context."""
        self.context.set_certificate(cert, key)

    def connect(self, host, port, timeout=3):
        """Wrap the underlying socket in TLS and connect to ``host:port``.

        Returns whatever the underlying client's ``connect`` returns.
        """
        sock = self.context.wrap(self.s.fd(), host)
        wrapper = socket.SocketWrapper(sock)
        # Replace the plain client with one that uses the wrapped socket.
        self.s = socket.TCPClient(wrapper)
        return self.s.connect(host, port, timeout)

    def send(self, data):
        """Send ``data`` through the underlying client."""
        self.s.send(data)

    def recv(self, num=4096):
        """Receive through the underlying client; ``num`` is passed to its ``recv``."""
        return self.s.recv(num)

    def close(self):
        """Close the underlying client."""
        self.s.close()

    def remote_certificate(self):
        """Return the peer's certificate as an SSLCertificate, or None.

        None is returned when ``getpeercert`` yields nothing.
        """
        # Passing True requests the certificate in binary (DER) form.
        cert = self.fd().getpeercert(True)
        if cert:
            return SSLCertificate.parse(cert, TYPE_DER)
        return None

    def fd(self):
        """Return the underlying client's ``fd()``."""
        return self.s.fd()

    def local_address(self):
        """Return the underlying client's local address."""
        return self.s.local_address()

    def remote_address(self):
        """Return the underlying client's remote address."""
        return self.s.remote_address()


class SSLServer:
    """TLS server layered on top of ``socket.TCPServer``."""

    def __init__(self, version=VERSION_TLS12, server=None):
        """Create a server with a freshly generated key and certificate.

        ``server`` is an existing server object to use; when it is falsy a
        new ``socket.TCPServer`` is created. Use ``set_certificate`` to
        replace the generated credentials.
        """
        self.server = server
        if not self.server:
            self.server = socket.TCPServer()

        key = SSLPrivateKey.generate()
        cert = SSLCertificate.generate(key)
        self.context = SSLContext(version)
        self.context.set_certificate(cert, key)

    def set_certificate(self, cert, key):
        """Set the certificate and private key on this server's context."""
        self.context.set_certificate(cert, key)

    def start(self, host, port):
        """Wrap the listening socket in TLS and start the server on ``host:port``."""
        sock = self.context.wrap(self.server.fd())
        wrapper = socket.SocketWrapper(sock)
        # Replace the plain server with one that uses the wrapped socket.
        self.server = socket.TCPServer(wrapper)
        self.server.start(host, port)

    def accept(self):
        """Accept a connection and return it as an SSLClient.

        Returns None when the underlying ``accept`` yields nothing.
        """
        client = self.server.accept()
        if client:
            return SSLClient(sock=client)
        return None