import logging
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from socket import error as SocketError

import paramiko
from paramiko import AuthenticationException, SSHException, ChannelException

from margaritashotgun.auth import AuthMethods
# NOTE(review): the star import supplies SSHConnectionError, SSHCommandError and
# AuthenticationMethodMissingError used below; consider importing them explicitly.
from margaritashotgun.exceptions import *

logger = logging.getLogger(__name__)

# paramiko / socket-level failures that are wrapped in project exceptions below.
_SSH_ERRORS = (AuthenticationException, SSHException, ChannelException, SocketError)


class Commands(Enum):
    """Shell command templates executed on the remote host."""

    # Prints the second field of the MemTotal line in /proc/meminfo.
    mem_size = "cat /proc/meminfo | grep MemTotal | awk '{ print $2 }'"
    kernel_version = "uname -r"
    lime_pattern = "{0}:{1}"
    lime_check = "cat /proc/net/tcp"
    # {0}: module file passed to insmod, {1}: value of LiME's tcp path
    # (presumably a port), {2}: LiME output format.
    load_lime = 'sudo insmod {0} "path=tcp:{1}" format={2}'
    unload_lime = "sudo pkill insmod; sudo rmmod lime"


class RemoteShell:
    """SSH session to one remote host, optionally tunnelled through a jump host."""

    def __init__(self, max_async_threads=2):
        """
        :type max_async_threads: int
        :param max_async_threads: maximum number of async command executors
        """
        self.jump_host_ssh = None
        # Set by connect(); initialised here so execute() can report a closed
        # transport even before a connection has been attempted.
        self.target_address = None
        self.ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without verification.
        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.executor = ThreadPoolExecutor(max_workers=max_async_threads)
        # Futures created by execute_async(); cancelled (if still pending) by cleanup().
        self.futures = []

    def connect(self, auth, address, port, jump_host=None, jump_auth=None):
        """
        Creates an ssh session to a remote host

        :type auth: object
        :param auth: authentication object exposing ``method`` (an
            :py:class:`margaritashotgun.auth.AuthMethods` member), ``username``
            and ``key`` or ``password``
        :type address: str
        :param address: remote server address
        :type port: int
        :param port: remote server port
        :type jump_host: dict or None
        :param jump_host: optional jump host with ``'addr'`` and ``'port'`` keys;
            when given, the target session is tunnelled through it
        :type jump_auth: object
        :param jump_auth: authentication object for the jump host (same shape as ``auth``)
        :raises SSHConnectionError: if paramiko or the socket layer fails while
            connecting to the jump host or the target
        """
        self.target_address = address
        try:
            sock = None
            if jump_host is not None:
                self.jump_host_ssh = paramiko.SSHClient()
                self.jump_host_ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
                self.connect_with_auth(self.jump_host_ssh, jump_auth,
                                       jump_host['addr'], jump_host['port'], None)
                # Open a direct-tcpip channel from the jump host to the target and
                # use it as the socket for the target session.
                transport = self.jump_host_ssh.get_transport()
                dest_addr = (address, port)
                jump_addr = (jump_host['addr'], jump_host['port'])
                sock = transport.open_channel('direct-tcpip', dest_addr, jump_addr)
            self.connect_with_auth(self.ssh, auth, address, port, sock)
        except _SSH_ERRORS as ex:
            raise SSHConnectionError("{0}:{1}".format(address, port), ex)

    def connect_with_auth(self, ssh, auth, address, port, sock):
        """
        Connect ``ssh`` to ``address:port`` using the method chosen by ``auth.method``.

        :type ssh: :py:class:`paramiko.SSHClient`
        :param ssh: client to connect
        :type auth: object
        :param auth: authentication object exposing ``method``, ``username`` and
            ``key`` or ``password``
        :type address: str
        :param address: remote server address
        :type port: int
        :param port: remote server port
        :param sock: socket-like object to connect over, or None for a direct connection
        :raises AuthenticationMethodMissingError: if ``auth.method`` is neither
            ``AuthMethods.key`` nor ``AuthMethods.password``
        """
        logger.debug("{0}: paramiko client connecting to "
                     "{0}:{1} with {2}".format(address, port, auth.method))
        if auth.method == AuthMethods.key:
            self.connect_with_key(ssh, auth.username, auth.key, address, port, sock)
        elif auth.method == AuthMethods.password:
            self.connect_with_password(ssh, auth.username, auth.password,
                                       address, port, sock)
        else:
            raise AuthenticationMethodMissingError()
        logger.debug("{0}: paramiko client connected to "
                     "{0}:{1}".format(address, port))

    def connect_with_password(self, ssh, username, password, address, port, sock,
                              timeout=20):
        """
        Create an ssh session to a remote host with a username and password

        :type ssh: :py:class:`paramiko.SSHClient`
        :param ssh: client to connect
        :type username: str
        :param username: username used for ssh authentication
        :type password: str
        :param password: password used for ssh authentication
        :type address: str
        :param address: remote server address
        :type port: int
        :param port: remote server port
        :param sock: socket-like object to connect over, or None
        :type timeout: int
        :param timeout: connection timeout passed to ``SSHClient.connect``
        """
        ssh.connect(username=username, password=password, hostname=address,
                    port=port, sock=sock, timeout=timeout)

    def connect_with_key(self, ssh, username, key, address, port, sock, timeout=20):
        """
        Create an ssh session to a remote host with a username and rsa key

        :type ssh: :py:class:`paramiko.SSHClient`
        :param ssh: client to connect
        :type username: str
        :param username: username used for ssh authentication
        :type key: :py:class:`paramiko.key.RSAKey`
        :param key: paramiko rsa key used for ssh authentication
        :type address: str
        :param address: remote server address
        :type port: int
        :param port: remote server port
        :param sock: socket-like object to connect over, or None
        :type timeout: int
        :param timeout: connection timeout passed to ``SSHClient.connect``
        """
        ssh.connect(hostname=address, port=port, username=username, pkey=key,
                    sock=sock, timeout=timeout)

    def transport(self):
        """
        Return the session's transport after tuning it for bulk transfer.

        Enables compression, sets the window size to 2**31 - 1 and raises the
        packetizer rekey thresholds to 2**40 (presumably so long uploads are
        not interrupted by frequent rekeying).
        """
        transport = self.ssh.get_transport()
        transport.use_compression(True)
        transport.window_size = 2147483647
        transport.packetizer.REKEY_BYTES = pow(2, 40)
        transport.packetizer.REKEY_PACKETS = pow(2, 40)
        return transport

    def execute(self, command):
        """
        Executes command on remote hosts

        :type command: str
        :param command: command to be run on remote host
        :rtype: dict
        :return: ``{'stdin': ..., 'stdout': ..., 'stderr': ...}`` streams
            returned by ``SSHClient.exec_command``
        :raises SSHConnectionError: if the ssh transport is closed
        :raises SSHCommandError: if paramiko or the socket layer fails
        """
        try:
            if self.ssh.get_transport() is None:
                raise SSHConnectionError(self.target_address, "ssh transport is closed")
            logger.debug('{0}: executing "{1}"'.format(self.target_address, command))
            stdin, stdout, stderr = self.ssh.exec_command(command)
            return {'stdin': stdin, 'stdout': stdout, 'stderr': stderr}
        except _SSH_ERRORS as ex:
            logger.critical("{0} execution failed on {1} with exception: "
                            "{2}".format(command, self.target_address, ex))
            raise SSHCommandError(self.target_address, command, ex)

    def execute_async(self, command, callback=None):
        """
        Executes command on remote hosts without blocking

        :type command: str
        :param command: command to be run on remote host
        :type callback: function
        :param callback: function to call with the future when execution completes
        :rtype: :py:class:`concurrent.futures.Future`
        :return: future resolving to the dict returned by :py:meth:`execute`
        :raises SSHCommandError: if a paramiko/socket error escapes while
            registering the callback
        """
        try:
            logger.debug('{0}: execute async "{1}" '
                         'with callback {2}'.format(self.target_address, command,
                                                    callback))
            future = self.executor.submit(self.execute, command)
            # Track the future so cleanup() can cancel it while still pending.
            self.futures.append(future)
            if callback is not None:
                # If the future has already finished, the callback runs right here,
                # which is why this call sits inside the try.
                future.add_done_callback(callback)
            return future
        except _SSH_ERRORS as ex:
            logger.critical("{0} execution failed on {1} with exception: "
                            "{2}".format(command, self.target_address, ex))
            raise SSHCommandError(self.target_address, command, ex)

    def decode(self, stream, encoding='utf-8'):
        """
        Convert paramiko stream into a string

        :param stream: file-like object whose ``read()`` returns bytes
        :type encoding: str
        :param encoding: stream encoding
        :rtype: str
        :return: decoded contents with leading/trailing newlines stripped
        """
        data = stream.read().decode(encoding).strip("\n")
        if data != "":
            logger.debug('{0}: decoded "{1}" with encoding '
                         '{2}'.format(self.target_address, data, encoding))
        return data

    def upload_file(self, local_path, remote_path):
        """
        Upload a file from the local filesystem to the remote host

        SSH-level failures are logged as warnings rather than raised.

        :type local_path: str
        :param local_path: path of local file to upload
        :type remote_path: str
        :param remote_path: destination path of upload on remote host
        """
        logger.debug("{0}: uploading {1} to {0}:{2}".format(self.target_address,
                                                             local_path,
                                                             remote_path))
        try:
            sftp = paramiko.SFTPClient.from_transport(self.transport())
            try:
                sftp.put(local_path, remote_path)
            finally:
                # Always release the SFTP channel, even when the upload fails.
                sftp.close()
        except SSHException as ex:
            logger.warning("{0}: LiME module upload failed with exception: "
                           "{1}".format(self.target_address, ex))

    def cleanup(self):
        """
        Release resources used during shell execution

        Cancels pending async commands, waits for running ones to finish, then
        closes the target session and, if one was opened, the jump host session.
        """
        for future in self.futures:
            future.cancel()
        self.executor.shutdown(wait=True)
        # Close the target first: its socket is a channel on the jump host session.
        if self.ssh.get_transport() is not None:
            self.ssh.close()
        if (self.jump_host_ssh is not None
                and self.jump_host_ssh.get_transport() is not None):
            self.jump_host_ssh.close()