import collections import logging import multiprocessing import numbers import pickle import signal import subprocess import six import sys import uuid from .. import constants from .. import __version__ from ..exceptions import DisconnectedError, SessionClosedError, UnknownMethodError logger = logging.getLogger(__name__) RAW_TYPES = six.string_types+(six.binary_type, numbers.Number, BaseException) try: from setproctitle import setproctitle except ImportError: logger.info('Cannot set process name.') setproctitle = lambda title: None class ClientSession(object): def __init__(self, transport): self.transport = transport self.connection = transport.client_get_connection() self._closed = True @property def closed(self): return self._closed def __getattr__(self, name): return RemoteObjWrapper(self, name) def __enter__(self): self._closed = False return self def __exit__(self, exc_type, exc_val, exc_tb): self._closed = True try: self.transport.client_close(self.connection) except Exception as e: logger.exception(e) # we don't want to hide AttributeErrors, etc. from end user return False class RemoteObjRef(object): def __init__(self, name): if hasattr(name, 'decode'): name = name.decode('utf-8') self.name = name def __str__(self): return self.name class RemoteObjWrapper(object): _patch_functions = ['next', '__next__', '__iter__'] def __init__(self, session, name): self.session = session self.name = name self._property_cache = {} for name in self._patch_functions: self._property_cache[name] = None def _send(self, method, *args, **kwargs): session = self.session if session.closed: raise SessionClosedError() if method == 'GET': func = session.transport.send_get_request elif method == 'CALL': func = session.transport.send_call_request else: raise UnknownMethodError(method) ret = func(session.connection, self.name, *args, **kwargs) if isinstance(ret, RemoteObjRef): ret = RemoteObjWrapper(session, ret.name) return ret def __getattr__(self, name): return self._send('GET', name) def __call__(self, *args, **kwargs): return self._send('CALL', *args, **kwargs) def _get_prop(self, name): val = self._property_cache[name] if val: return val self._property_cache[name] = self.__getattr__(name) return self._property_cache[name] @property def next(self): return self._get_prop('next') @property def __next__(self): return self._get_prop('__next__') @property def __iter__(self): return self._get_prop('__iter__') class Request(object): def __init__(self, method, path, headers, body): self.method = method self.path = path self.headers = headers self.body = body class Response(object): def __init__(self, status, headers, body): self.status = status self.headers = headers self.body = body def worker_init(*args): name = multiprocessing.current_process().name logger.debug('Worker initialized: {}'.format(name)) signal.signal(signal.SIGINT, signal.SIG_IGN) setproctitle('errand-boy worker process {}'.format(name.split('-')[1])) def worker(self, connection): logger.debug('worker connected') self.server_handle_client(connection) class BaseTransport(object): """ Base class providing functionality common to all transports. """ def __init__(self): pass def connection_to_string(self, connection): return repr(connection) def server_get_connection(self): raise NotImplementedError() def server_recv(self, connection): raise NotImplementedError() def server_send(self, connection, data): raise NotImplementedError() def server_close(self, connection): pass def translate_obj(self, exposed_locals, val): if isinstance(val, RemoteObjRef): val = exposed_locals[val.name] return val def server_serialize(self, exposed_locals, obj): if obj is not None and not isinstance(obj, RAW_TYPES): name = six.text_type(uuid.uuid4()) exposed_locals[name] = obj obj = RemoteObjRef(name) return obj def server_handle_client(self, connection): connection = self.server_deserialize_connection(connection) logger.debug('server_handle_client: {}'.format(self.connection_to_string(connection))) exposed_locals = {'subprocess': subprocess} while True: # need to close connection when client not listening try: request = self.get_request(connection) except DisconnectedError: break raised = False obj = None if request.method == 'GET': name, attr = request.path.split('.') try: obj = getattr(exposed_locals[name], attr) except KeyError as e: obj = e raised = True elif request.method == 'CALL': name = request.path try: obj = exposed_locals[name] except KeyError as e: obj = e raised = True args, kwargs = pickle.loads(request.body) args = [self.translate_obj(exposed_locals, arg) for arg in args] for key in kwargs: kwargs[key] = self.translate_obj(exposed_locals, kwargs[key]) try: obj = obj(*args, **kwargs) except Exception as e: obj = e raised = True obj = self.server_serialize(exposed_locals, obj) self.send_response(connection, obj, raised=raised) self.server_close(connection) def server_accept(self, serverconnection): raise NotImplementedError() def server_deserialize_connection(self, connection): return connection def server_serialize_connection(self, connection): return connection def run_server(self, pool_size=10, max_accepts=5000, max_child_tasks=100): setproctitle('errand-boy master process') serverconnection = self.server_get_connection() logger.info('Accepting connections: {}'.format(self.connection_to_string(serverconnection))) logger.info('pool_size: {}'.format(pool_size)) logger.info('max_accepts: {}'.format(max_accepts)) logger.info('max_child_tasks: {}'.format(max_child_tasks)) pool = multiprocessing.Pool(pool_size, worker_init, tuple(), max_child_tasks) connections = [] remaining_accepts = max_accepts if not remaining_accepts: remaining_accepts = True try: while remaining_accepts: connection = self.server_accept(serverconnection) logger.info('Accepted connection from: {}'.format(self.connection_to_string(connection))) result = pool.apply_async(worker, [self, self.server_serialize_connection(connection)]) connection = None if remaining_accepts is not True: remaining_accepts -= 1 except KeyboardInterrupt: logger.info('Received KeyboardInterrupt') pool.terminate() except Exception as e: logger.exception(e) pool.terminate() raise finally: pool.close() pool.join() def client_get_connection(self): raise NotImplementedError() def client_recv(self, connection): raise NotImplementedError() def client_send(self, connection, command_string): raise NotImplementedError() def client_close(self, connection): pass def send_algo(self, connection, send_func, first_line, headers=None, body=None): CRLF = constants.CRLF msg = [first_line] msg.append(CRLF) if headers: for name, val in headers: msg.append('{}: {}'.format(name, val)) msg.append(CRLF) msg.append('Content-Length: {}'.format(len(body))) msg.append(CRLF) if body: msg.append(CRLF) msg.append(body) msg = [s.encode('utf-8') if hasattr(s, 'encode') else s for s in msg] msg = b''.join(msg) return send_func(connection, msg) def send_request(self, connection, method, path, body=''): first_line = "{method} {path}".format(method=method, path=path) self.send_algo(connection, self.client_send, first_line, headers=None, body=body) resp = self.get_response(connection) obj = pickle.loads(resp.body) if resp.status == 400: raise obj return obj def send_get_request(self, connection, prefix, name): return self.send_request(connection, 'GET', prefix+'.'+name) def send_call_request(self, connection, name, *args, **kwargs): kwargs = collections.OrderedDict(sorted(kwargs.items(), key=lambda t: t[0])) body = pickle.dumps([args, kwargs]) return self.send_request(connection, 'CALL', name, body=body) def recv_algo(self, connection, recv_func): CRLF = constants.CRLF lines = [] data = b'' content_length = None while True: new_data = recv_func(connection, 4096) if not new_data: raise DisconnectedError() data += new_data if not lines and CRLF in data: try: headers, body = data.split(CRLF + CRLF, 1) except ValueError: split_data = data.split(CRLF) lines.extend(split_data[:-1]) data = split_data[-1] else: lines.extend((headers + CRLF).split(CRLF)) data = body if lines and content_length is None: for line in lines: try: header, val = line.split(b': ') if header.lower() == b'content-length': content_length = int(val) break except ValueError: pass if content_length == 0 and len(lines) > 1: break if len(lines) > 2 and lines[-1] == b'': # needs length of body, minus already fetched data remaining_len = content_length - len(data) if remaining_len: data += recv_func(connection, remaining_len) # only break once all data has been returned if len(data) == content_length: break if data: data = CRLF + data data = CRLF.join(lines) + data try: headers, body = data.split(CRLF+CRLF) except ValueError: headers = data body = b'' headers = headers.split(CRLF) headers = [header.decode('utf-8') for header in headers] first_line = headers[0] headers = headers[1:] headers = [header.split(': ') for header in headers] return first_line, headers, body def get_request(self, connection): first_line, headers, body = self.recv_algo(connection, self.server_recv) method, path = first_line.split(' ', 1) return Request(method, path, headers, body) def send_response(self, connection, obj, raised=False): body = pickle.dumps(obj) first_line = '200 OK' if not raised else '400 Error' return self.send_algo(connection, self.server_send, first_line, body=body) def get_response(self, connection): first_line, headers, body = self.recv_algo(connection, self.client_recv) status = int(first_line.split()[0]) return Response(status, headers, body) def get_session(self): return ClientSession(self) def run_cmd(self, command_string): with self.get_session() as session: subprocess = session.subprocess process = subprocess.Popen(command_string, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.communicate() returncode = process.returncode return stdout, stderr, returncode