import socket
import struct
import time

import OpenSSL
# NOTE(review): this aliases WantReadError; it is what get_ssl_cert_domain raises
# when the peer sent no certificate.
SSLError = OpenSSL.SSL.WantReadError

from pyasn1.codec.der import decoder as der_decoder

import socks
import utils
from subj_alt_name import SubjectAltName
from . import openssl_wrap


class ConnectCreator(object):
    """Open TCP + TLS connections to a given IP and verify the peer certificate.

    The returned connection object is annotated with bookkeeping attributes
    (``h2``, ``sni``, ``ip_str``, ``fd``, ``connect_time``, ``handshake_time``,
    ...) that callers read later.
    """

    def __init__(self, logger, config, openssl_context, host_manager, timeout=5, debug=False, check_cert=None):
        """Store collaborators and apply proxy settings from ``config``.

        Args:
            logger: logger used for debug/error output.
            config: configuration object; proxy, buffer, ALPN-forcing and
                certificate-check settings are read from it.
            openssl_context: object whose ``.context`` is handed to
                ``openssl_wrap.SSLConnection``.
            host_manager: provides ``get_sni_host(ip_str)`` when no SNI is given.
            timeout: socket timeout in seconds.
            debug: when true, extra debug logging is emitted.
            check_cert: optional callable(ssl_sock) replacing the default
                ``check_cert`` method on this instance.
        """
        self.logger = logger
        self.config = config
        self.openssl_context = openssl_context
        self.host_manager = host_manager
        self.timeout = timeout
        self.debug = debug
        if check_cert:
            # An instance attribute shadows the check_cert method below.
            self.check_cert = check_cert

        self.update_config()
        self.connect_force_http1 = self.config.connect_force_http1
        self.connect_force_http2 = self.config.connect_force_http2

    def update_config(self):
        """Install the configured proxy as the ``socks`` module default.

        Does nothing unless ``config.PROXY_ENABLE`` is non-zero.

        Raises:
            Exception: if ``config.PROXY_TYPE`` is not HTTP, SOCKS4 or SOCKS5.
        """
        if not int(self.config.PROXY_ENABLE):
            return

        if self.config.PROXY_TYPE == "HTTP":
            proxy_type = socks.HTTP
        elif self.config.PROXY_TYPE == "SOCKS4":
            proxy_type = socks.SOCKS4
        elif self.config.PROXY_TYPE == "SOCKS5":
            proxy_type = socks.SOCKS5
        else:
            self.logger.error("proxy type %s unknown, disable proxy", self.config.PROXY_TYPE)
            raise Exception()

        socks.set_default_proxy(proxy_type, self.config.PROXY_HOST, self.config.PROXY_PORT,
                                self.config.PROXY_USER, self.config.PROXY_PASSWD)

    @staticmethod
    def get_subj_alt_name(peer_cert):
        '''
        Copied from ndg.httpsclient.ssl_peer_verification.ServerSSLCertVerification
        Extract subjectAltName entries from certificate extensions.

        @param peer_cert: peer certificate in SSL connection. subjectAltName
        settings if any will be extracted from this
        @type peer_cert: OpenSSL.crypto.X509
        @return: list of bytes, one per subjectAltName component; entries
        starting with b"*" (wildcards) are skipped.
        '''
        dns_name = []
        general_names = SubjectAltName()
        for i in range(peer_cert.get_extension_count()):
            ext = peer_cert.get_extension(i)
            if ext.get_short_name() != b"subjectAltName":
                continue

            # PyOpenSSL returns extension data in ASN.1 encoded form.
            ext_dat = ext.get_data()
            decoded_dat = der_decoder.decode(ext_dat, asn1Spec=general_names)

            # decode() returns (value, remainder); only the decoded value matches.
            for name in decoded_dat:
                if not isinstance(name, SubjectAltName):
                    continue
                for entry in range(len(name)):
                    component = name.getComponentByPosition(entry)
                    n = bytes(component.getComponent())
                    if n.startswith(b"*"):
                        continue
                    dns_name.append(n)
        return dns_name

    def get_ssl_cert_domain(self, ssl_sock):
        """Set ``ssl_sock.domain`` to the peer certificate's CN.

        Raises:
            SSLError: if the peer presented no certificate.
        """
        cert = ssl_sock.get_peer_certificate()
        if not cert:
            raise SSLError("no cert")
        ssl_cert = openssl_wrap.SSLCert(cert)
        ssl_sock.domain = ssl_cert.cn

    @staticmethod
    def _close_quietly(sock):
        """Close ``sock`` on a failure path without masking the original error."""
        try:
            sock.close()
        except OSError:
            pass

    def _create_socket(self, ip):
        """Create and configure the raw TCP socket for ``ip`` (bytes).

        A ``socks.socksocket`` is used when a proxy is enabled. The socket is
        closed again if any option cannot be applied.
        """
        family = socket.AF_INET6 if b':' in ip else socket.AF_INET
        if int(self.config.PROXY_ENABLE):
            sock = socks.socksocket(family)
        else:
            sock = socket.socket(family)

        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # struct linger{l_onoff=1, l_linger=0}: close() aborts immediately;
            # the original intent is to avoid socket error 10048.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
            # Enlarge the receive buffer to improve browser-related application performance.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.config.connect_receive_buffer)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(self.timeout)
        except BaseException:
            self._close_quietly(sock)
            raise
        return sock

    def _negotiated_h2(self, ssl_sock):
        """Return True if ``ssl_sock`` should speak HTTP/2.

        The config overrides win (force_http1 takes precedence over
        force_http2). Otherwise the ALPN result is used. If ALPN querying
        fails, the ``protos`` attribute of the underlying ``_connection`` is
        consulted instead.
        """
        if self.connect_force_http1:
            return False
        if self.connect_force_http2:
            return True
        try:
            return ssl_sock.get_alpn_proto_negotiated() == b"h2"
        except Exception:
            conn = ssl_sock._connection
            return hasattr(conn, "protos") and conn.protos == "h2"

    def connect_ssl(self, ip_str, sni=b"", close_cb=None):
        """Open a TCP connection to ``ip_str`` and complete a TLS handshake.

        Args:
            ip_str: address string parsed by ``utils.get_ip_port``.
            sni: server name to send. When empty, the SNI and host are taken
                from ``host_manager.get_sni_host(ip_str)``.
            close_cb: passed to ``openssl_wrap.SSLConnection`` as ``on_close``.

        Returns:
            The connected ``openssl_wrap.SSLConnection``. It is annotated with
            ``h2``, ``sni``, ``ip_str``, ``_sock`` (raw socket), ``fd``,
            ``create_time``, ``connect_time`` / ``handshake_time``
            (milliseconds since start), ``last_use_time``, ``host`` and
            ``received_size``.

        Raises:
            socket.error: if the TCP connect or TLS handshake fails.
            Exceptions raised by ``self.check_cert`` also propagate.
            On any failure the raw socket is closed before the error propagates.
        """
        if sni:
            host = sni
        else:
            sni, host = self.host_manager.get_sni_host(ip_str)
            host = str(host)
        if isinstance(sni, str):
            sni = bytes(sni, encoding='ascii')

        ip, port = utils.get_ip_port(ip_str)
        if isinstance(ip, str):
            ip = bytes(ip, encoding='ascii')

        sock = self._create_socket(ip)
        try:
            ssl_sock = openssl_wrap.SSLConnection(self.openssl_context.context, sock, ip_str, on_close=close_cb)
            ssl_sock.set_connect_state()

            if sni:
                if self.debug:
                    self.logger.debug("sni:%s", sni)
                try:
                    ssl_sock.set_tlsext_host_name(sni)
                except Exception as e:
                    # Best effort: the handshake proceeds without SNI.
                    self.logger.exception("set_tlsext_host_name %s except:%r", sni, e)

            time_begin = time.time()
            try:
                ssl_sock.connect((ip, port))
                time_connected = time.time()
                ssl_sock.do_handshake()
            except Exception as e:
                raise socket.error('conn fail, sni:%s, top:%s e:%r' % (sni, host, e)) from e

            ssl_sock.h2 = self._negotiated_h2(ssl_sock)
            time_handshaked = time.time()

            ssl_sock.sni = sni
            self.check_cert(ssl_sock)
        except BaseException:
            # The caller never receives ssl_sock here, so nobody else would close the fd.
            self._close_quietly(sock)
            raise

        connect_time = int((time_connected - time_begin) * 1000)
        handshake_time = int((time_handshaked - time_begin) * 1000)

        # Callers sometimes use the raw TCP socket directly (select/epoll),
        # so expose it on the SSL object.
        ssl_sock.ip_str = ip_str
        ssl_sock._sock = sock
        ssl_sock.fd = sock.fileno()
        ssl_sock.create_time = time_begin
        ssl_sock.connect_time = connect_time
        ssl_sock.handshake_time = handshake_time
        ssl_sock.last_use_time = time_handshaked
        ssl_sock.host = host
        ssl_sock.received_size = 0

        return ssl_sock

    def check_cert(self, ssl_sock):
        """Verify the peer certificate chain of ``ssl_sock`` against config.

        The checks, each gated by config, are: a non-empty chain; a minimum
        chain length; a pinned public key for the intermediate CA; an issuer
        CN prefix; and the subjectAltName names. As a side effect,
        ``ssl_sock.domain`` is set via ``get_ssl_cert_domain``.

        Raises:
            socket.error: when a check fails.
        """
        cert_chain = ssl_sock.get_peer_cert_chain()
        if not cert_chain:
            raise socket.error('certificate is none, sni:%s' % ssl_sock.sni)

        if len(cert_chain) < self.config.min_intermediate_CA:
            raise socket.error('No intermediate CA was found.')

        # Old OpenSSL bindings lack dump_publickey; pinning is skipped there.
        if self.config.check_pkp and hasattr(OpenSSL.crypto, "dump_publickey"):
            if len(cert_chain) < 2:
                # The pinned key is the intermediate's, at cert_chain[1].
                raise socket.error('No intermediate CA was found.')
            pub_key = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_PEM, cert_chain[1].get_pubkey())
            if pub_key not in self.config.CHECK_PKP:
                raise socket.error('The intermediate CA is mismatching.')

        self.get_ssl_cert_domain(ssl_sock)

        # get_components() yields bytes pairs, so the default is bytes too.
        issuer_commonname = next(
            (v for k, v in cert_chain[0].get_issuer().get_components() if k == b'CN'), b'')

        if self.debug:
            for cert in cert_chain:
                for k, v in cert.get_issuer().get_components():
                    if k == b"CN":
                        self.logger.debug("cn:%s", v)
            self.logger.debug("issued by:%s", issuer_commonname)
            self.logger.debug("Common Name:%s", ssl_sock.domain)

        expected_issuer = self.config.check_commonname
        if expected_issuer:
            # Compare bytes with bytes; a str config value would raise TypeError.
            if isinstance(expected_issuer, str):
                expected_issuer = expected_issuer.encode("utf-8")
            if not issuer_commonname.startswith(expected_issuer):
                raise socket.error(' certificate is issued by %r' % (issuer_commonname))

        cert = ssl_sock.get_peer_certificate()
        if not cert:
            raise socket.error('certificate is none')

        if self.config.check_sni:
            # get_subj_alt_name costs nearly 100ms, so only run it when asked.
            try:
                alt_names = ConnectCreator.get_subj_alt_name(cert)
            except Exception:
                # NOTE(review): b"" is a suffix of every name, so in suffix mode
                # below a parse failure lets the check pass.
                alt_names = [b""]

            if self.debug:
                self.logger.debug('alt names: "%s"', b'", "'.join(alt_names))

            if isinstance(self.config.check_sni, str):
                # alt_names holds bytes; a raw str would never match.
                if self.config.check_sni.encode("utf-8") not in alt_names:
                    raise socket.error('check sni fail, alt_names:%s' % (alt_names))
            else:
                alt_names = tuple(alt_names)
                if not ssl_sock.sni.endswith(alt_names):
                    raise socket.error('check sni:%s fail, alt_names:%s' % (ssl_sock.sni, alt_names))