# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import errno
import logging
import os
import re
import select
import socket
import subprocess
import sys
import threading
import time
from distutils import spawn

import fixtures
import jinja2
import psutil
import six

from pifpaf import util

try:
    import xattr
except ImportError:
    xattr = None

if six.PY3:
    fsdecode = os.fsdecode
else:
    def fsdecode(s):
        if isinstance(s, unicode):  # noqa
            return s
        return s.decode(sys.getfilesystemencoding())

LOG = logging.getLogger(__name__)

if six.PY2:
    # This is defined only on Python 3
    ProcessLookupError = None  # noqa

# errno values that mean "extended attributes are not supported here".
# On Linux both equal 95 (the value this check used to hard-code); other
# platforms may report either one.
_XATTR_UNSUPPORTED_ERRNOS = (
    errno.EOPNOTSUPP,
    getattr(errno, "ENOTSUP", errno.EOPNOTSUPP),
)


class Driver(fixtures.Fixture):
    """Base fixture for pifpaf drivers.

    Provides a per-driver temporary data directory, helpers to export
    prefixed environment variables and render Jinja2 templates, and a
    process launcher (``_exec``) that registers a cleanup for every child
    it starts.
    """

    def __init__(self, env_prefix="PIFPAF", templatedir=".", debug=False,
                 tmp_rootdir=None):
        """Create a new driver.

        :param env_prefix: prefix prepended (with ``_``) to the variable
            names exported through :meth:`putenv`.
        :param templatedir: sub-directory of ``drivers/templates`` inside
            the ``pifpaf`` package from which templates are loaded.
        :param debug: stored on the instance as ``self.debug``.
        :param tmp_rootdir: passed to ``fixtures.TempDir`` as the directory
            under which the data directory is created (``None`` lets
            fixtures pick its default).
        """
        super(Driver, self).__init__()
        self.env_prefix = env_prefix
        # Every variable exported through putenv(), keyed by its full name.
        self.env = {}
        self.debug = debug
        self.tmp_rootdir = tmp_rootdir
        templatedir = os.path.join('drivers', 'templates', templatedir)
        self.template_env = jinja2.Environment(
            loader=jinja2.PackageLoader('pifpaf', templatedir))

    def _setUp(self):
        # Each driver gets its own scratch directory, managed by a
        # fixtures.TempDir, and exported as <PREFIX>_DATA.
        self.tempdir = self.useFixture(
            fixtures.TempDir(self.tmp_rootdir)).path
        self.putenv("DATA", self.tempdir)

    @staticmethod
    def get_options():
        """Return the extra options this driver accepts (none by default)."""
        return []

    def putenv(self, key, value, raw=False):
        """Export an environment variable for the lifetime of the fixture.

        :param key: variable name; prefixed with ``<env_prefix>_`` unless
            ``raw`` is true.
        :param value: value to set.
        :param raw: use ``key`` verbatim, without the prefix.
        :return: the ``fixtures.EnvironmentVariable`` fixture in use.
        """
        if not raw:
            key = self.env_prefix + "_" + key
        self.env[key] = value
        return self.useFixture(fixtures.EnvironmentVariable(key, value))

    def _ensure_xattr_support(self):
        """Raise RuntimeError unless the temp dir supports user xattrs.

        A ``user.test`` attribute is written on a scratch file in
        ``self.tempdir``. When the ``xattr`` module is not importable,
        support is treated as missing.

        :raises RuntimeError: if extended attributes cannot be used.
        """
        testfile = os.path.join(self.tempdir, "test")
        self._touch(testfile)
        xattr_supported = False
        if xattr is not None:
            try:
                x = xattr.xattr(testfile)
                x[b"user.test"] = b"test"
            except (OSError, IOError) as e:
                # Only "not supported" means no xattr; anything else is a
                # genuine error and is propagated.
                if e.errno not in _XATTR_UNSUPPORTED_ERRNOS:
                    raise
            else:
                xattr_supported = True
        if not xattr_supported:
            raise RuntimeError("TMPDIR must support xattr for %s"
                               % self.__class__.__name__)

    def _kill(self, parent):
        """Clean up ``parent`` and join its output-reader thread, if any.

        :param parent: a process returned by :meth:`_exec`.
        """
        log_thread = getattr(parent, "_log_thread", None)
        util.process_cleaner(parent)
        if log_thread:
            # The parent process has been cleaned up; the reader thread is
            # expected to hit EOF shortly, so give it a bounded wait.
            log_thread.join(timeout=3)
            if log_thread.is_alive():
                LOG.warning("logging thread for `%s` is still alive", parent)

    @staticmethod
    def find_executable(filename, extra_paths):
        """Return the path to ``filename``, or ``None`` if it is not found.

        ``extra_paths`` (a list) is searched before the entries of ``PATH``.
        """
        # NOTE(review): distutils was removed in Python 3.12; shutil.which
        # is the Python 3 replacement once Python 2 support is dropped.
        paths = extra_paths + os.getenv('PATH', os.defpath).split(os.pathsep)
        for path in paths:
            loc = spawn.find_executable(filename, path)
            if loc is not None:
                return loc

    @staticmethod
    def find_config_file(filename):
        """Return the first existing ``<etc dir>/filename``.

        :raises RuntimeError: if no candidate directory contains it.
        """
        # NOTE(sileht): order matters: we first check the virtualenv,
        # then the global user installation, next the system installation,
        # and finally the local user installation.
        check_dirs = [sys.prefix + "/etc",
                      "/usr/local/etc",
                      "/etc",
                      os.path.expanduser("~/.local/etc")]
        for d in check_dirs:
            fullpath = os.path.join(d, filename)
            if os.path.exists(fullpath):
                return fullpath
        raise RuntimeError("Configuration file `%s' not found" % filename)

    def _read_in_bg(self, app, pid, fd):
        """Log every line read from ``fd`` until EOF, then close it."""
        try:
            while True:
                data = fd.readline()
                if not data:
                    break
                self._log_output(app, pid, data)
        finally:
            fd.close()

    @staticmethod
    def _log_output(appname, pid, data):
        data = fsdecode(data)
        LOG.debug("%s[%d] output: %s", appname, pid, data.rstrip())

    def _exec(self, command, stdout=False, ignore_failure=False,
              stdin=None, wait_for_line=None, wait_for_port=None,
              path=None, env=None,
              forbidden_line_after_start=None,
              allow_debug=True):
        """Spawn ``command`` and optionally wait for it to become ready.

        The child runs in a new session (``os.setsid``) with stderr merged
        into stdout, and a cleanup calling :meth:`_kill` is registered.

        :param command: argv list; ``command[0]`` is used as the log name.
        :param stdout: read output until EOF and return it.
        :param ignore_failure: do not raise when a synchronously waited
            command exits with a non-zero status.
        :param stdin: bytes written to the child's stdin, which is then
            closed.
        :param wait_for_line: regex; read output until a line matches it.
        :param wait_for_port: TCP port on 127.0.0.1 polled (10 tries, one
            second apart) until it accepts connections.
        :param path: directories prepended to ``PATH`` for the child.
        :param env: extra environment variables for the child.
        :param forbidden_line_after_start: ``(timeout, regex)``; after
            start-up, wait up to ``timeout`` seconds for more output and
            raise if it matches ``regex``. Only honoured together with
            ``stdout`` or ``wait_for_line``.
        :param allow_debug: when false, never pipe output to the debug log.
        :return: ``(process, output)``; ``output`` is the bytes read during
            start-up, or ``None`` when output was not captured.
        :raises RuntimeError: if the command cannot be started, exits
            before printing ``wait_for_line``, prints a forbidden line, does
            not open ``wait_for_port``, or (when waited for) fails.
        """
        LOG.debug("executing: %s", command)
        app = command[0]
        debug = allow_debug and LOG.isEnabledFor(logging.DEBUG)
        capture = stdout or wait_for_line

        if path or env:
            complete_env = dict(os.environ)
            if env:
                complete_env.update(env)
            if path:
                complete_env.update({
                    "PATH": (os.pathsep.join(path) + os.pathsep
                             + os.getenv("PATH", "")),
                })
        else:
            complete_env = None

        # /dev/null handles are only needed until Popen has handed them to
        # the child; the parent's copies are closed right afterwards so they
        # do not leak one file descriptor per spawned process.
        opened = []
        try:
            if capture or debug:
                stdout_fd = subprocess.PIPE
            else:
                stdout_fd = open(os.devnull, 'w')
                opened.append(stdout_fd)

            if stdin:
                stdin_fd = subprocess.PIPE
            else:
                stdin_fd = open(os.devnull, 'r')
                opened.append(stdin_fd)

            try:
                c = psutil.Popen(
                    command,
                    close_fds=True,
                    stdin=stdin_fd,
                    stdout=stdout_fd,
                    stderr=subprocess.STDOUT,
                    env=complete_env,
                    preexec_fn=os.setsid,
                )
            except OSError as e:
                raise RuntimeError(
                    "Unable to run command `%s': %s"
                    % (" ".join(command), e))
        finally:
            for f in opened:
                f.close()

        self.addCleanup(self._kill, c)

        if stdin:
            LOG.debug("%s input: %s", app, stdin)
            c.stdin.write(stdin)
            c.stdin.close()

        if capture:
            lines = []
            while True:
                line = c.stdout.readline()
                self._log_output(app, c.pid, line)
                lines.append(line)
                if not line:
                    if wait_for_line:
                        raise RuntimeError(
                            "Program did not print: `%s'\nOutput: %s"
                            % (wait_for_line, b"".join(lines)))
                    break
                decoded_line = fsdecode(line)
                if wait_for_line and re.search(wait_for_line, decoded_line):
                    break
            stdout_str = b"".join(lines)
        else:
            stdout_str = None

        if capture and forbidden_line_after_start:
            timeout, forbidden_output = forbidden_line_after_start
            r, w, x = select.select([c.stdout.fileno()], [], [], timeout)
            if r:
                line = c.stdout.readline()
                self._log_output(app, c.pid, line)
                new_lines = [line]
                if c.poll() is not None:
                    # The process is dead: read the rest of its output,
                    # which helps debugging and is checked below too.
                    while line:
                        line = c.stdout.readline()
                        self._log_output(app, c.pid, line)
                        new_lines.append(line)
                lines.extend(new_lines)
                # Check every line read in this window, not only the last
                # one (which is the empty EOF marker when the process died).
                if any(candidate and
                       re.search(forbidden_output, fsdecode(candidate))
                       for candidate in new_lines):
                    raise RuntimeError(
                        "Program print a forbidden line: `%s'\nOutput: %s"
                        % (forbidden_output, b"".join(lines)))

        if capture or debug:
            # Keep draining the pipe in the background so the child does
            # not stall writing to a full pipe; the thread ends at EOF.
            t = threading.Thread(target=self._read_in_bg,
                                 args=(app, c.pid, c.stdout,))
            t.daemon = True
            t.start()
            # Store the thread ref on the Process() so _kill can join it.
            c._log_thread = t

        if wait_for_port:
            for _ in range(10):
                with contextlib.closing(
                        socket.socket(socket.AF_INET,
                                      socket.SOCK_STREAM)) as sock:
                    if sock.connect_ex(('127.0.0.1', wait_for_port)) == 0:
                        break
                time.sleep(1)
            else:
                raise RuntimeError("Program did not open port %s"
                                   % wait_for_port)

        # Without a readiness condition the command is treated as a
        # one-shot job and waited for synchronously.
        if not wait_for_line and not wait_for_port:
            status = c.wait()
            if not ignore_failure and status != 0:
                raise RuntimeError("Error while running command: %s"
                                   % command)

        return c, stdout_str

    def _touch(self, fname):
        """Create ``fname`` if missing and update its timestamps."""
        open(fname, 'a').close()
        os.utime(fname, None)

    def template(self, resource, env, dest):
        """Render template ``resource`` with ``env`` into file ``dest``."""
        template = self.template_env.get_template(resource)
        with open(dest, 'w') as f:
            f.write(template.render(**env))