# Copyright 2015 Internap.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from _socket import timeout, gaierror
import codecs
import logging
import time

import paramiko

from netman.adapters import shell
from netman.adapters.shell.base import TerminalClient
from netman.core.objects.exceptions import CouldNotConnect, ConnectTimeout, CommandTimeout


class SshClient(TerminalClient):
    """Interactive SSH shell client that reads output until a marker is seen.

    One interactive shell channel is opened at construction time. Commands
    are written to it, and output is accumulated until the buffer ends with
    the expected prompt (or a caller-supplied marker).
    """

    def __init__(self, host, username, password, port=22, prompt=('>', '#'),
                 connect_timeout=None, command_timeout=None,
                 reading_interval=0.01, reading_chunk_size=9999):
        """Connect to ``host``, open a shell and wait for the first prompt.

        :param host: host name or address to connect to.
        :param username: login user name.
        :param password: login password; passed to the connection only,
            not stored on the instance.
        :param port: SSH port.
        :param prompt: string or tuple of strings; output is considered
            complete when the buffer ends with one of them.
        :param connect_timeout: timeout given to ``paramiko.SSHClient.connect``;
            falls back to ``shell.default_connect_timeout`` when falsy.
        :param command_timeout: seconds to wait for a marker before raising
            ``CommandTimeout``; falls back to
            ``shell.default_command_timeout`` when falsy.
        :param reading_interval: seconds slept between polls of the channel.
        :param reading_chunk_size: maximum size passed to each ``recv`` call.
        :raises ConnectTimeout: the connection attempt timed out.
        :raises CouldNotConnect: the host name could not be resolved.
        :raises CommandTimeout: the initial prompt never arrived.
        """
        self.logger = logging.getLogger(__name__)
        self.host = host
        self.port = port
        self.username = username
        self.prompt = prompt
        self.command_timeout = command_timeout or shell.default_command_timeout
        connect_timeout = connect_timeout or shell.default_connect_timeout
        self.reading_interval = reading_interval
        self.reading_chunk_size = reading_chunk_size

        # Output read since the last _wait_for() started.
        self.current_buffer = ''
        self.client = None
        self.channel = None
        # Everything ever received on this connection.
        self.full_log = ""

        self._open_channel(host, port, username, password, connect_timeout)

    def do(self, command, wait_for=None, include_last_line=False):
        """Send ``command`` followed by a newline and return its output lines.

        :param command: text to send.
        :param wait_for: marker ending the output; defaults to ``self.prompt``.
        :param include_last_line: keep the final line (normally the prompt).
        :returns: list of non-empty output lines, without the first line.
        """
        self._log("Send >>", command)
        self.channel.send(command + '\n')
        return self._read_until(wait_for, include_last_line)

    def send_key(self, key, wait_for=None, include_last_line=False):
        """Send ``key`` as-is (no newline appended) and return the output lines.

        Parameters and return value are the same as for :meth:`do`.
        """
        self._log("Send KEY >>", key)
        self.channel.send(key)
        return self._read_until(wait_for, include_last_line)

    def quit(self, command):
        """Send ``command`` followed by a newline without reading any output."""
        self._log("Quit >>", command)
        self.channel.send(command + '\n')

    def get_current_prompt(self):
        """Return the last line of the current buffer, or '' if it is empty."""
        lines = self.current_buffer.splitlines()
        return lines[-1] if lines else ''

    def _log(self, direction, payload):
        """Emit a debug record tagged with the connection endpoint."""
        self.logger.debug("[SSH][%s@%s:%s] %s %s",
                          self.username, self.host, self.port, direction, payload)

    def _open_channel(self, host, port, username, password, connect_timeout):
        """Connect, start an interactive shell and wait for the first prompt.

        If any step fails, the SSH client is closed before the exception
        propagates. Otherwise a failed constructor would leave an open
        transport that no caller holds a reference to.
        """
        self.client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; confirm this is acceptable for the deployment.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        opened = False
        try:
            try:
                self.client.connect(host, port=port, username=username, password=password,
                                    timeout=connect_timeout, allow_agent=False, look_for_keys=False)
            except timeout:
                raise ConnectTimeout(host, port)
            except gaierror:
                raise CouldNotConnect(host, port)

            self.channel = self.client.invoke_shell()
            self._wait_for(self.prompt)
            opened = True
        finally:
            # A flag plus try/finally (rather than except + bare raise) keeps
            # the original exception intact even if close() handles an
            # exception of its own.
            if not opened:
                self.client.close()

    def _read_until(self, wait_for, include_last_line):
        """Wait for ``wait_for`` (or the prompt) and return the output lines.

        The first line of the buffer (presumably the echoed input) is
        dropped. The last line is dropped unless ``include_last_line`` is
        true. Empty lines are removed.

        :returns: a list on every Python version (``filter`` would yield a
            lazy iterator on Python 3).
        """
        self._wait_for(wait_for or self.prompt)
        lines = self.current_buffer.splitlines()[1:]
        if not include_last_line:
            lines = lines[:-1]
        return [line for line in lines if line]

    def _wait_for(self, wait_for):
        """Read from the channel until the buffer ends with ``wait_for``.

        :param wait_for: string or tuple of strings, as accepted by
            ``str.endswith``.
        :raises CommandTimeout: more than ``command_timeout`` seconds have
            elapsed since this call started while waiting for data.
        """
        self.current_buffer = ''
        # On Python 3 recv() returns bytes. An incremental decoder avoids
        # corrupting multi-byte characters that straddle a chunk boundary.
        decoder = codecs.getincrementaldecoder('utf-8')('replace')
        started_at = time.time()
        while not self.current_buffer.endswith(wait_for):
            while not self.channel.recv_ready():
                # The deadline is only checked while idle. A channel that
                # keeps producing data is read until the marker appears.
                if time.time() - started_at > self.command_timeout:
                    raise CommandTimeout(wait_for, self.current_buffer)
                time.sleep(self.reading_interval)
            read = self.channel.recv(self.reading_chunk_size)
            self._log("Recv <<", repr(read))
            if not isinstance(read, str):
                read = decoder.decode(read)
            self.full_log += read
            self.current_buffer += read