import logging import os import select import signal import subprocess import sys import threading from typing import List, Optional from dmoj.cptbox._cptbox import * from dmoj.cptbox.handlers import ALLOW, DISALLOW, _CALLBACK from dmoj.cptbox.syscalls import SYSCALL_COUNT, by_id, translator from dmoj.error import InternalError from dmoj.utils.communicate import safe_communicate as _safe_communicate from dmoj.utils.os_ext import ( ARCH_A64, ARCH_ARM, ARCH_X32, ARCH_X64, ARCH_X86, INTERPRETER_ARCH, file_arch, find_exe_in_path, oom_score_adj, OOM_SCORE_ADJ_MAX, ) from dmoj.utils.unicode import utf8bytes, utf8text PIPE = subprocess.PIPE log = logging.getLogger('dmoj.cptbox') _PIPE_BUF = getattr(select, 'PIPE_BUF', 512) _SYSCALL_INDICIES: List[Optional[int]] = [None] * 7 if 'freebsd' in sys.platform: _SYSCALL_INDICIES[DEBUGGER_X64] = 4 else: _SYSCALL_INDICIES[DEBUGGER_X86] = 0 _SYSCALL_INDICIES[DEBUGGER_X86_ON_X64] = 0 _SYSCALL_INDICIES[DEBUGGER_X64] = 1 _SYSCALL_INDICIES[DEBUGGER_X32] = 2 _SYSCALL_INDICIES[DEBUGGER_ARM] = 3 _SYSCALL_INDICIES[DEBUGGER_ARM64] = 5 # (python arch, executable arch) -> debugger _arch_map = { (ARCH_X86, ARCH_X86): DEBUGGER_X86, (ARCH_X64, ARCH_X64): DEBUGGER_X64, (ARCH_X64, ARCH_X86): DEBUGGER_X86_ON_X64, (ARCH_X64, ARCH_X32): DEBUGGER_X32, (ARCH_X32, ARCH_X32): DEBUGGER_X32, (ARCH_X32, ARCH_X86): DEBUGGER_X86_ON_X64, (ARCH_ARM, ARCH_ARM): DEBUGGER_ARM, (ARCH_A64, ARCH_ARM): DEBUGGER_ARM, (ARCH_A64, ARCH_A64): DEBUGGER_ARM64, } class MaxLengthExceeded(ValueError): pass class AdvancedDebugger(Debugger): # Implements additional debugging functionality for convenience. @property def syscall_name(self): return self.get_syscall_name(self.syscall) def get_syscall_name(self, syscall): callname = 'unknown' index = self._syscall_index for id, call in enumerate(translator): if syscall in call[index]: callname = by_id[id] break return callname def readstr(self, address, max_size=4096): if self.address_bits == 32: address &= 0xFFFFFFFF read = super().readstr(address, max_size + 1) if read is None: return None if len(read) > max_size: raise MaxLengthExceeded(read) return utf8text(read) # SecurePopen is a subclass of a cython class, _cptbox.Process. Since it is exceedingly unwise # to do everything in cython, determining the debugger class is left to do here. However, since # the debugger is constructed in __cinit__, we have to pass the determined debugger class to # SecurePopen.__new__. While we can simply override __new__, many complication arises from having # different parameters to __new__ and __init__, the latter of which is given the *original* arguments # as passed to type.__call__. Hence, we use a metaclass to pass the extra debugger argument to both # __new__ and __init__. class TracedPopenMeta(type): def __call__(self, argv, executable=None, *args, **kwargs): executable = executable or find_exe_in_path(argv[0]) arch = file_arch(executable) debugger = _arch_map.get((INTERPRETER_ARCH, arch)) if debugger is None: raise RuntimeError('Executable type %s could not be debugged on Python type %s' % (arch, INTERPRETER_ARCH)) return super().__call__(debugger, self.debugger_type, argv, executable, *args, **kwargs) class TracedPopen(Process, metaclass=TracedPopenMeta): debugger_type = AdvancedDebugger def __init__( self, debugger, _, args, executable=None, security=None, time=0, memory=0, stdin=PIPE, stdout=PIPE, stderr=None, env=None, nproc=0, fsize=0, address_grace=4096, data_grace=0, personality=0, cwd='', fds=None, wall_time=None, ): self._debugger_type = debugger self._syscall_index = index = _SYSCALL_INDICIES[debugger] self._executable = executable or find_exe_in_path(args[0]) self._args = args self._chdir = cwd self._env = [ utf8bytes('%s=%s' % (arg, val)) for arg, val in (env if env is not None else os.environ).items() if val is not None ] self._time = time self._wall_time = time * 3 if wall_time is None else wall_time self._cpu_time = time + 5 if time else 0 self._memory = memory self._child_personality = personality self._child_memory = memory * 1024 + data_grace * 1024 self._child_address = memory * 1024 + address_grace * 1024 if memory else 0 self._nproc = nproc self._fsize = fsize self._is_tle = False self._is_ole = False self._fds = fds self.__init_streams(stdin, stdout, stderr) self.protection_fault = None self.debugger._syscall_index = index self.debugger.address_bits = 64 if debugger in (DEBUGGER_X64, DEBUGGER_ARM64) else 32 self._security = security self._callbacks = [None] * MAX_SYSCALL_NUMBER self._syscall_whitelist = [False] * MAX_SYSCALL_NUMBER if security is None: self._trace_syscalls = False else: for i in range(SYSCALL_COUNT): handler = security.get(i, DISALLOW) for call in translator[i][index]: if call is None: continue if isinstance(handler, int): self._syscall_whitelist[call] = handler == ALLOW else: if not callable(handler): raise ValueError('Handler not callable: ' + handler) self._callbacks[call] = handler handler = _CALLBACK self._handler(call, handler) self._died = threading.Event() self._spawned_or_errored = threading.Event() self._spawn_error = None if time: # Spawn thread to kill process after it times out self._shocker = threading.Thread(target=self._shocker_thread) self._shocker.start() self._worker = threading.Thread(target=self._run_process) self._worker.start() self._spawned_or_errored.wait() if self._spawn_error: raise self._spawn_error def wait(self): self._died.wait() if not self.was_initialized: if self.returncode == 203: raise RuntimeError('failed to set up seccomp policy') elif self.returncode == 204: raise RuntimeError( 'failed to ptrace child, check Yama config ' '(https://www.kernel.org/doc/Documentation/security/Yama.txt, should be ' 'at most 1); if running in Docker, must run container with `--cap-add=SYS_PTRACE`' ) elif self.returncode == 205: raise RuntimeError('failed to spawn child') return self.returncode def poll(self): return self.returncode def mark_ole(self): self._is_ole = True @property def is_ir(self): return self.returncode > 0 @property def is_mle(self): return self._memory and self.max_memory > self._memory @property def is_ole(self): return self._is_ole @property def is_rte(self): return self.returncode is None or self.returncode < 0 # Killed by signal @property def is_tle(self): return self._is_tle def kill(self): # FIXME(quantum): this is actually a race. The process may exit before we kill it. # Under very unlikely circumstances, the pid could be reused and we will end up # killing the wrong process. if self.returncode is None: log.warning('Request the killing of process: %s', self.pid) try: os.killpg(self.pid, signal.SIGKILL) except OSError: import traceback traceback.print_exc() else: log.warning('Skipping the killing of process because it already exited: %s', self.pid) def _callback(self, syscall): try: callback = self._callbacks[syscall] except IndexError: if self._syscall_index == 3: # ARM-specific return 0xF0000 < syscall < 0xF0006 return False if callback is not None: return callback(self.debugger) return False def _protection_fault(self, syscall): # When signed, 0xFFFFFFFF is equal to -1, meaning that ptrace failed to read the syscall for some reason. # We can't continue debugging as this could potentially be unsafe, so we should exit loudly. # See <https://github.com/DMOJ/judge/issues/181> for more details. if syscall == 0xFFFFFFFF: raise InternalError('ptrace failed') # TODO: this would be more useful if we had access to a proper errno # import errno, os # err = ... # raise InternalError('ptrace error: %d (%s: %s)' % (err, errno.errorcode[err], os.strerror(err))) else: callname = self.debugger.get_syscall_name(syscall) self.protection_fault = ( syscall, callname, [ self.debugger.uarg0, self.debugger.uarg1, self.debugger.uarg2, self.debugger.uarg3, self.debugger.uarg4, self.debugger.uarg5, ], ) def _cpu_time_exceeded(self): log.warning('SIGXCPU in process %d', self.pid) self._is_tle = True def _run_process(self): try: self._spawn(self._executable, self._args, self._env, self._chdir, self._fds) except: # noqa: E722, need to catch absolutely everything self._spawn_error = sys.exc_info()[0] self._died.set() return finally: if self.stdin_needs_close: os.close(self._child_stdin) if self.stdout_needs_close: os.close(self._child_stdout) if self.stderr_needs_close: os.close(self._child_stderr) self._spawned_or_errored.set() # Adjust OOM score on the child process, sacrificing it before the judge process. try: oom_score_adj(OOM_SCORE_ADJ_MAX, self.pid) except Exception: import traceback traceback.print_exc() # TODO(tbrindus): this code should be the same as [self.returncode], so it shouldn't be duplicated code = self._monitor() if self._time and self.execution_time > self._time: self._is_tle = True self._died.set() return code def _shocker_thread(self): # On Linux, ignored signals still cause a notification under ptrace. # Hence, we use SIGWINCH, harmless and ignored signal to make wait4 return # pt_process::monitor, causing time to be updated. # On FreeBSD, a signal must not be ignored in order for wait4 to return. # Hence, we swallow SIGSTOP, which should never be used anyway, and use it # force an update. wake_signal = signal.SIGSTOP if 'freebsd' in sys.platform else signal.SIGWINCH self._spawned_or_errored.wait() while not self._died.wait(1): if self.execution_time > self._time or self.wall_clock_time > self._wall_time: log.warning('Shocker activated and killed %d', self.pid) self.kill() self._is_tle = True break try: os.killpg(self.pid, wake_signal) except OSError: pass def __init_streams(self, stdin, stdout, stderr): self.stdin = self.stdout = self.stderr = None self.stdin_needs_close = self.stdout_needs_close = self.stderr_needs_close = False if stdin == PIPE: self._child_stdin, self._stdin = os.pipe() self.stdin = os.fdopen(self._stdin, 'wb') self.stdin_needs_close = True elif isinstance(stdin, int): self._child_stdin, self._stdin = stdin, -1 elif stdin is not None: self._child_stdin, self._stdin = stdin.fileno(), -1 else: self._child_stdin = self._stdin = -1 if stdout == PIPE: self._stdout, self._child_stdout = os.pipe() self.stdout = os.fdopen(self._stdout, 'rb') self.stdout_needs_close = True elif isinstance(stdout, int): self._stdout, self._child_stdout = -1, stdout elif stdout is not None: self._stdout, self._child_stdout = -1, stdout.fileno() else: self._stdout = self._child_stdout = -1 if stderr == PIPE: self._stderr, self._child_stderr = os.pipe() self.stderr = os.fdopen(self._stderr, 'rb') self.stderr_needs_close = True elif isinstance(stderr, int): self._stderr, self._child_stderr = -1, stderr elif stderr is not None: self._stderr, self._child_stderr = -1, stderr.fileno() else: self._stderr = self._child_stderr = -1 communicate = _safe_communicate def unsafe_communicate(self, input=None): return _safe_communicate(self, input=input, outlimit=sys.maxsize, errlimit=sys.maxsize) def can_debug(arch): return (INTERPRETER_ARCH, arch) in _arch_map