import argparse import hmac import io import json import subprocess import time import traceback from typing import List, Tuple import flask import libvirt import xmltodict import yaml from flask import Flask, request class LibvirtConfig: def __init__(self, data: dict): self.uri: str = data["uri"] self.domain: str = data["domain"] class HTTPConfig: def __init__(self, data: dict): addr = data["address"].rsplit(':', maxsplit=1) self.host: str = addr[0].strip('[]') self.port: int = int(addr[1]) self._security = data["security"] @property def is_secure(self) -> bool: return self._security["enabled"] @property def secret(self) -> str: return self._security["secret"] class CommandsConfig: def __init__(self, data: dict): self.host_commands: List[str] = data.get("host", []) self.guest_commands: List[str] = data.get("guest", []) class Config: def __init__(self, data: dict): self.http = HTTPConfig(data["http"]) self.devices = [(d["vendor"], d["product"]) for d in data["devices"]] self.displays = data["displays"] self.libvirt = LibvirtConfig(data["libvirt"]) self.commands = CommandsConfig(data.get("commands", {})) @staticmethod def load(filename: str): with io.open(filename) as f: return Config(yaml.safe_load(f)) class Virt: def __init__(self, uri: str, domain: str): self._con = libvirt.open(uri) self._dom = self._con.lookupByName(domain) def get_devices(self) -> List[dict]: devs = [] desc = xmltodict.parse(self._dom.XMLDesc()) for dev in desc["domain"]["devices"]["hostdev"]: if dev["@type"] == "usb": devs.append(dev) return devs @staticmethod def get_device_ids(desc: dict) -> Tuple[int, int]: return (int(desc["source"]["vendor"]["@id"], 16), int(desc["source"]["product"]["@id"], 16)) def get_device_by_ids(self, ids: Tuple[int, int]) -> dict: for dev in self.get_devices(): if self.get_device_ids(dev) == ids: return dev return None def attach_devices(self, devs: List[dict]): for ids in devs: dev = self.get_device_by_ids(ids) if dev is None: dev = xmltodict.unparse({ "hostdev": { "@mode": "subsystem", "@type": "usb", "source": { "vendor": {"@id": hex(ids[0])}, "product": {"@id": hex(ids[1])} } } }) self._dom.attachDevice(dev) def detach_devices(self, devs: List[dict]): for dev in self.get_devices(): if self.get_device_ids(dev) in devs: xml = xmltodict.unparse({"hostdev": dev}) self._dom.detachDevice(xml) class Switch: def __init__(self, config: Config): self.config = config self.virt = Virt(config.libvirt.uri, config.libvirt.domain) @staticmethod def _call_dccutil(display: dict, ident: int): return subprocess.call([ "ddcutil", "--bus", str(display["bus"]), "setvcp", hex(display["feature"]), hex(ident) ]) @staticmethod def _call_commands(command: str): return subprocess.call(command, shell=True) def switch_to_host(self): for display in self.config.displays: self._call_dccutil(display, display["host"]) for command in self.config.commands.host_commands: self._call_commands(command) self.virt.detach_devices(self.config.devices) def switch_to_guest(self): for display in self.config.displays: self._call_dccutil(display, display["guest"]) for command in self.config.commands.guest_commands: self._call_commands(command) self.virt.attach_devices(self.config.devices) switch: Switch = None app = Flask(__name__) @app.route("/switch", methods=["POST"]) def app_switch(): if switch.config.http.is_secure: secret = request.headers.get("X-Secret") if secret is None \ or not hmac.compare_digest(switch.config.http.secret, secret): flask.abort(403) cases = { "host": switch.switch_to_host, "guest": switch.switch_to_guest } if not request.json \ or not "to" in request.json \ or not request.json["to"] in cases: flask.abort(400) error = None try: cases[request.json["to"]]() except: error = traceback.format_exc() return flask.jsonify({"success": True, "error": error}) def main(): parser = argparse.ArgumentParser(description="The poor man's KVM switch for libvirt and VFIO users", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--config", dest="config", required=True, help="the YAML configuration file") args = parser.parse_args() global switch config = Config.load(args.config) switch = Switch(config) app.run(host=config.http.host, port=config.http.port)