# -*- coding: UTF-8 -*- # Copyright (C) 2019 Parrot Drones SAS # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in # the documentation and/or other materials provided with the # distribution. # * Neither the name of the Parrot Company nor the names # of its contributors may be used to endorse or promote products # derived from this software without specific prior written # permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE # PARROT COMPANY BE LIABLE FOR ANY DIRECT, INDIRECT, # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS # OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED # AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT # OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF # SUCH DAMAGE. from __future__ import unicode_literals from __future__ import absolute_import import functools import pprint import time from abc import ABC, abstractmethod from aenum import Enum from boltons.setutils import IndexedSet from concurrent.futures import Future, as_completed from concurrent.futures import TimeoutError as FutureTimeoutError from concurrent.futures import CancelledError as FutureCancelledError from collections import OrderedDict, deque from logging import getLogger from olympe._private import ( callback_decorator, merge_mapping, timestamp_now, equals, DEFAULT_FLOAT_TOL, ) from olympe._private.pomp_loop_thread import PompLoopThread from olympe.arsdkng.events import EventContext, MultipleEventContext, ArsdkMessageEvent from olympe.arsdkng.event_marker import EventMarker from olympe.arsdkng.listener import Subscriber import threading class AbstractScheduler(ABC): __slots__ = () @abstractmethod def schedule(self, expectations, **kwds): pass @abstractmethod def subscribe( self, callback, expectation=None, queue_size=None, default=None, timeout=None ): """ Subscribe a callback to some specific event expectation or to all events if no specific event expectation is given in parameter. :param callback: a callable object (function, method, function-like object, ...) :param expectation: an event expectation object (ex: `FlyingStateChanged()`) :param queue_size: this subscriber queue size or None if unbounded (the default) :type queue_size: int :param timeout: the callback timeout in seconds or None for infinite timeout (the default) :rtype: Subscriber """ @abstractmethod def unsubscribe(self, subscriber): """ Unsubscribe a previously registered subscriber :param subscriber: the subscriber previously returned by :py:func:`~olympe.Drone.subscribe` :type subscriber: Subscriber """ class Namespace: pass class DefaultScheduler(AbstractScheduler): __slots__ = "_attr" def __init__(self, pomp_loop_thread, name=None, device_name=None): self._attr = Namespace() self._attr.default = Namespace() self._attr.default.name = name self._attr.default.device_name = device_name if self._attr.default.name is not None: self._attr.default.logger = getLogger( "olympe.{}.scheduler".format(self._attr.default.name) ) elif self._attr.default.device_name is not None: self._attr.default.logger = getLogger( "olympe.scheduler.{}".format(self._attr.default.device_name) ) else: self._attr.default.logger = getLogger("olympe.scheduler") # Expectations internal state self._attr.default.contexts = OrderedDict() self._attr.default.pending_expectations = [] self._attr.default.pomp_loop_thread = pomp_loop_thread # Setup expectations monitoring timer, this is used to detect timedout # expectations periodically self._attr.default.expectations_timer = self._attr.default.pomp_loop_thread.create_timer( lambda timer, userdata: self._garbage_collect() ) if not self._attr.default.pomp_loop_thread.set_timer( self._attr.default.expectations_timer, delay=200, period=15 ): error_message = "Unable to launch piloting interface" self._attr.default.logger.error(error_message) raise RuntimeError(error_message) # Subscribers internal state self._attr.default.subscribers_lock = threading.Lock() self._attr.default.subscribers = [] self._attr.default.running_subscribers = OrderedDict() self._attr.default.subscribers_thread_loop = PompLoopThread( self._attr.default.logger ) self._attr.default.subscribers_thread_loop.start() def add_context(self, name, context): self._attr.default.contexts[name] = context def remove_context(self, name): return self._attr.default.contexts.pop(name, None) is not None def context(self, name): return self._attr.default.contexts[name] def schedule(self, expectations, **kwds): # IMPORTANT note: the schedule method should ideally be called from # this scheduler pomp loop thread. This method should not be blocking # on any input/output (this is true if all expectations.check/_schedule # method are non-blocking). # To ensure that `self._schedule()` is called in the right thread we # execute it through the pomp loop run_async function. If we are already # in the pomp loop thread, this `self._schedule()` is called # synchronously self._attr.default.pomp_loop_thread.run_async( self._schedule, expectations, **kwds ).result() def run(self, *args, **kwds): return self._attr.default.pomp_loop_thread.run_async( *args, **kwds ) @callback_decorator() def _schedule(self, expectation, **kwds): expectation._schedule(self) monitor = kwds.get("monitor", True) if monitor and not expectation.success(): self._attr.default.pending_expectations.append(expectation) def process_event(self, event): self._attr.default.pomp_loop_thread.run_async(self._process_event, event) @callback_decorator() def _process_event(self, event): # For all current pending expectations garbage_collected_expectations = [] for expectation in self._attr.default.pending_expectations: if expectation.cancelled() or expectation.timedout(): # Garbage collect canceled/timedout expectations garbage_collected_expectations.append(expectation) elif expectation.check(event).success(): # If an expectation successfully matched a message, signal the expectation # and remove it from the currently monitored expectations. expectation.set_success() garbage_collected_expectations.append(expectation) # Remove the garbage collected expectations for expectation in garbage_collected_expectations: self._attr.default.pending_expectations.remove(expectation) # Notify subscribers self._attr.default.pomp_loop_thread.run_later(self._notify_subscribers, event) @callback_decorator() def _garbage_collect(self): # For all currently pending expectations garbage_collected_expectations = [] for expectation in self._attr.default.pending_expectations: # Collect cancelled or timedout expectation # The actual cancel/timeout check is delegated to the expectation if expectation.cancelled() or expectation.timedout(): garbage_collected_expectations.append(expectation) # Remove the collected expectations for expectation in garbage_collected_expectations: self._attr.default.pending_expectations.remove(expectation) def stop(self): for expectation in self._attr.default.pending_expectations: expectation.cancel() self._attr.default.pending_expectations = [] def destroy(self): self.stop() self._attr.default.subscribers_thread_loop.stop() @callback_decorator() def _notify_subscribers(self, event): with self._attr.default.subscribers_lock: defaults = OrderedDict.fromkeys( ( s._default for s in self._attr.default.subscribers if s._default is not None ) ) for subscriber in self._attr.default.subscribers: checked = subscriber.notify(event) if checked: if subscriber._default is not None: defaults.pop(subscriber._default, None) future = self._attr.default.subscribers_thread_loop.run_async( subscriber.process ) self._attr.default.running_subscribers[id(subscriber)] = future future.add_done_callback( functools.partial( lambda subscriber, _: self._attr.default.running_subscribers.pop( id(subscriber) ), subscriber, ) ) for default in defaults: default.notify(event) self._attr.default.subscribers_thread_loop.run_async(default.process) def subscribe( self, callback, expectation=None, queue_size=Subscriber.default_queue_size, default=None, timeout=None, ): """ Subscribe a callback to some specific event expectation or to all events if no specific event expectation is given in parameter. :param callback: a callable object (function, method, function-like object, ...) :param expectation: an event expectation object (ex: `FlyingStateChanged()`) :param queue_size: this subscriber queue size or None if unbounded (the default) :type queue_size: int :param timeout: the callback timeout in seconds :rtype: Subscriber """ subscriber = Subscriber( self, callback, expectation=expectation, queue_size=queue_size, default=default, timeout=timeout, ) with self._attr.default.subscribers_lock: self._attr.default.subscribers.append(subscriber) return subscriber def unsubscribe(self, subscriber): """ Unsubscribe a previously registered subscriber :param subscriber: the subscriber previously returned by :py:func:`~olympe.Drone.subscribe` :type subscriber: Subscriber """ with self._attr.default.subscribers_lock: future = self._attr.default.running_subscribers.pop(id(subscriber), None) if future is not None: try: future.result(subscriber.timeout) except Exception as e: self._attr.default.logger.exception(e) self._attr.default.subscribers.remove(subscriber) def _subscriber_overrun(self, subscriber, event): self._attr.default.logger.warning( "Subscriber {} event queue ({}) is overrun by {}".format( subscriber, subscriber.queue_size, event ) ) class StreamSchedulerMixin: """ StreamScheduler is scheduler class decorator (decorator pattern not an actual python decorator) that acts as a queuing discipline to limit the maximum number of parallelized expectation processing. """ __slots__ = () def __init__(self, *args, stream_timeout=None, max_parallel_processing=1, **kwds): """ :param scheduler: the decorated scheduler :param stream_timeout: the default timeout value in seconds used by StreamScheduler.join :param max_parallel_processing: the maximum number of parallelized expectation processing (defaults to 1) """ queue_size = 1024 self._attr.stream_scheduler = Namespace() self._attr.stream_scheduler.timeout = stream_timeout self._attr.stream_scheduler.max_parallel_processing = max_parallel_processing self._attr.stream_scheduler.token_count = threading.BoundedSemaphore( max_parallel_processing ) self._attr.stream_scheduler.expectation_queue = deque([], queue_size) self._attr.stream_scheduler.pending_expectations = set() self._attr.stream_scheduler.on_done_condition = threading.Condition() @callback_decorator() def _schedule(self, expectation, **kwds): """ Schedule one expectation processing if the maximum number of parallel processing has not been reached yet. Otherwise, the expectation will remain in an internal pending queue until at least one expectation processing is done. """ self._attr.stream_scheduler.expectation_queue.append((expectation, kwds)) self._stream_schedule() def _stream_schedule(self): # try to schedule expectations from the queue if possible # while at least one token in available while self._attr.stream_scheduler.expectation_queue and ( self._attr.stream_scheduler.token_count.acquire(blocking=False) ): expectation, kwds = self._attr.stream_scheduler.expectation_queue.popleft() self._attr.stream_scheduler.pending_expectations.add(expectation) expectation.add_done_callback(self._stream_on_done) super()._schedule(expectation, **kwds) def _stream_on_done(self, expectation): # release one token self._attr.stream_scheduler.token_count.release() self._attr.stream_scheduler.pending_expectations.remove(expectation) # try to schedule one expectation self._stream_schedule() # notify that we're done with one expectation processing with self._attr.stream_scheduler.on_done_condition: self._attr.stream_scheduler.on_done_condition.notify_all() def stream_join(self, timeout=None): """ Wait for all currently pending expectations """ if timeout is None: timeout = self._attr.stream_scheduler.timeout with self._attr.stream_scheduler.on_done_condition: self._attr.stream_scheduler.on_done_condition.wait_for( lambda: ( not bool(self._attr.stream_scheduler.pending_expectations) and not bool(self._attr.stream_scheduler.expectation_queue) ), timeout=timeout, ) class SchedulerDecoratorContext: def __init__(self, decorated): self._decorated = decorated def __getattr__(self, name): return getattr(self._decorated, name) def decorate(self, name, decorator, *args, **kwds): if issubclass(self._decorated.__class__, decorator): # We've already applied this decorator, nothing to be done return namespace = dict(decorator.__dict__) self._decorated.__class__ = type( name, (decorator, type(self._decorated)), namespace ) decorator.__init__(self._decorated, *args, **kwds) class Scheduler(SchedulerDecoratorContext): def __init__(self, *args, **kwds): super().__init__(DefaultScheduler(*args, **kwds)) class ExpectPolicy(Enum): wait, check, check_wait = range(3) class ExpectationBase(ABC): always_monitor = False def __init__(self): self._future = Future() self._awaited = False self._scheduler = None self._success = False self._timeout = None self._deadline = None self._timedout = False # FIXME: float_tol should be moved to ArsdkExpectationBase self._float_tol = DEFAULT_FLOAT_TOL def _schedule(self, scheduler): # This expectation is scheduled on the `scheduler`, subclasses of ExpectationBase can # perform some operations on this scheduler: schedule another expectation later or # perform an operation on the scheduler object when this expectation is schedule (like # sending a message for which this expectation object expect some result). # IMPORTANT NOTE: this function (or its overridden versions) should be non-blocking self._awaited = True self._scheduler = scheduler if self._timeout is not None: self._deadline = timestamp_now() + self._timeout def success(self): return self._success def wait(self, _timeout=None): if self._awaited: try: self._future.result(timeout=_timeout) except FutureTimeoutError: self.set_timedout() except FutureCancelledError: self.cancel() return self def add_done_callback(self, cb): self._future.add_done_callback(lambda f: cb(self)) def set_success(self): if not self._future.done(): self._success = True self._future.set_result(self.received_events()) return True return False def set_exception(self, exception): if not self._future.done(): self._future.set_exception(exception) def set_timeout(self, _timeout): self._timeout = _timeout def set_timedout(self): if self._future.done(): return False if not self._success: self._timedout = True self.cancel() return True return False def cancel(self): if self._future.done(): return False self._future.cancel() return True def cancelled(self): return self._future.cancelled() def timedout(self): if self._timedout: return True if self._success: return False if self._deadline is not None: timedout = timestamp_now() > self._deadline if timedout: self.set_timedout() return self._timedout def set_float_tol(self, _float_tol): self._float_tol = _float_tol def base_copy(self, *args, **kwds): other = self.__class__(*args, **kwds) ExpectationBase.__init__(other) other._timeout = self._timeout other._float_tol = self._float_tol return other @abstractmethod def copy(self): """ All expectations sublclasses must implement a shallow copy. """ pass def done(self): return (self._future.done() or not self._awaited) and self._success def __bool__(self): return self.done() def __or__(self, other): return WhenAnyExpectation([self, other]) def __and__(self, other): return WhenAllExpectations([self, other]) def __rshift__(self, other): return WhenSequenceExpectations([self, other]) __nonzero__ = __bool__ class SuccessExpectation(ExpectationBase): def __init__(self): super().__init__() self.set_success() def copy(self): return super().base_copy() def received_events(self): return None class FailedExpectation(ExpectationBase): def __init__(self, message): super().__init__() self._message = message self.set_exception(RuntimeError(message)) def copy(self): return super().base_copy(self._message) def explain(self): return self._message class FutureExpectation(ExpectationBase): def __init__(self, future, status_checker=lambda status: True): super().__init__() self._future = future self._status_checker = status_checker self._future.add_done_callback(self._on_done) def _on_done(self, f): if f.exception() is None: self._success = self._status_checker(f.result()) def check(self, *args, **kwds): return self def copy(self): return super().base_copy(self._future, self._status_checker) class Expectation(ExpectationBase): @abstractmethod def check(self, *args, **kwds): # IMPORTANT NOTE: this function (or its overridden versions) should be non-blocking pass @abstractmethod def expected_events(self): pass @abstractmethod def received_events(self): """ Returns a collection of events that have matched at least one of the messages ID monitored by this expectation. """ pass @abstractmethod def matched_events(self): """ Returns a collection of events that have matched this expectation (or a child expectation) """ pass @abstractmethod def unmatched_events(self): """ Returns a collection of events object that are still expected """ pass def marked_events(self, default_marked_events=EventMarker.unmatched): """ Returns a collection of events with matched/unmatched markers. """ if self._success: return self.expected_events()._set_marker(EventMarker.matched) else: return self.expected_events()._set_marker(default_marked_events) def explain(self): """ Returns a debug string that explain this expectation current state. """ try: return str(self.marked_events()) except Exception: getLogger("olympe.expectations").exception("") return None class ArsdkExpectationBase(Expectation): def __init__(self): super().__init__() self._deprecated_statedict = False def _set_deprecated_statedict(self): self._deprecated_statedict = True @abstractmethod def _fill_default_arguments(self, message, args): pass @abstractmethod def check(self, received_event, *args, **kwds): pass class ArsdkFillDefaultArgsExpectationMixin(object): def _fill_default_arguments(self, message, args): for argname, argval in self.expected_args.copy().items(): if callable(argval): # command message expectation args mapping self.expected_args[argname] = argval(message, args) elif argval is None and argname in args: # default argument handling self.expected_args[argname] = args[argname] class ArsdkEventExpectation(ArsdkFillDefaultArgsExpectationMixin, ArsdkExpectationBase): def __init__(self, expected_message, expected_args): super().__init__() self.expected_message = expected_message.new() self.expected_args = OrderedDict() for k, v in expected_args.items(): self.expected_args[k] = v self.received_args = [] self._received_events = [] self.matched_args = OrderedDict() def copy(self): return super().base_copy( self.expected_message.copy(), self.expected_args.copy() ) def check(self, received_event, *args, **kwds): if not isinstance(received_event, ArsdkMessageEvent): return self if received_event.message.id != self.expected_message.id: return self self._received_events.append(received_event) self.received_args.append(received_event.args) for arg_name, arg_val in self.expected_args.items(): if arg_val is None: continue if arg_name not in received_event.args: return self if not equals( received_event.args[arg_name], self.expected_args[arg_name], float_tol=self._float_tol, ): return self if not self._success: self.matched_args = received_event.args.copy() self.set_success() return self def expected_events(self): if not self._deprecated_statedict: return EventContext( [ArsdkMessageEvent(self.expected_message, self.expected_args)] ) else: return { self.expected_message.FULL_NAME: { k.upper(): v for k, v in self.expected_args.items() } } def received_events(self): if not self._deprecated_statedict: if not self._received_events: return EventContext() return EventContext(self._received_events[:]) else: return { self.expected_message.FULL_NAME: [ {k.upper(): v for k, v in args.items()} for args in self.received_args ] } def matched_events(self): if not self._deprecated_statedict: if self._success: if not self.matched_args: return EventContext() return EventContext( [ArsdkMessageEvent(self.expected_message, self.matched_args)] ) else: return EventContext() else: if self._success: return { self.expected_message.FULL_NAME: { k.upper(): v for k, v in self.matched_args.items() } } else: return {} def unmatched_events(self): if not self._deprecated_statedict: if not self._success: return EventContext(self.expected_events().events()) else: return EventContext() else: if not self._success: return self.expected_events() else: return {} def marked_events(self, default_marked_events=EventMarker.unmatched): if not self._deprecated_statedict: return super().marked_events(default_marked_events=default_marked_events) else: if not self._success: return self.expected_events() else: return {} @classmethod def from_arsdk(cls, messages, ar_expectation): expected_message = messages.by_id_name[ar_expectation.id.lstrip("#")] # When a list item is expected without arguments # expect the last and/or empty element if not ar_expectation.arguments and ( expected_message._is_list_item() or expected_message._is_map_item() ): expectations = [] for event in ("Last", "Empty"): args = OrderedDict() event = expected_message.args_bitfield["list_flags"](event) args["list_flags"] = event expectations.append(cls(expected_message, args)) return ArsdkWhenAnyExpectation(expectations) args = OrderedDict() for arg in ar_expectation.arguments: if arg.value.startswith("this."): argname = arg.value[5:] args[arg.name] = ( lambda argname: lambda command_message, command_args: command_args[ argname ] )(argname) elif arg.name in expected_message.args_enum: args[arg.name] = expected_message.args_enum[arg.name][arg.value] elif arg.name in expected_message.args_bitfield: args[arg.name] = expected_message.args_bitfield[arg.name](arg.value) else: args[arg.name] = int(arg.value) return cls(expected_message, args) def __iter__(self): return iter((self,)) def __len__(self): return 1 def __repr__(self): return pprint.pformat({self.expected_message.FullName: self.expected_args}) class ArsdkCheckStateExpectation( ArsdkFillDefaultArgsExpectationMixin, ArsdkExpectationBase ): def __init__(self, expected_message, expected_args): super().__init__() self.expected_message = expected_message.new() self.expected_args = expected_args self.matched_state = None def copy(self): return super().base_copy( self.expected_message.copy(), self.expected_args.copy() ) def check(self, received_event, *args, **kwds): return self def _schedule(self, scheduler): super()._schedule(scheduler) controller = scheduler.context("olympe.controller") try: if controller.check_state( self.expected_message, _float_tol=self._float_tol, **self.expected_args ): self.matched_state = controller.get_state(self.expected_message) self.set_success() else: self.cancel() except KeyError: # state not found pass def expected_events(self): if not self._deprecated_statedict: return EventContext( [ArsdkMessageEvent(self.expected_message, self.expected_args)] ) else: return { self.expected_message.FULL_NAME: { k.upper(): v for k, v in self.expected_args.items() } } def received_events(self): return EventContext() if not self._deprecated_statedict else {} def matched_events(self): if not self._deprecated_statedict: if self._success: if not self.matched_state: return EventContext() return EventContext( [ ArsdkMessageEvent( self.expected_message, self.matched_state, ExpectPolicy.check, ) ] ) else: return EventContext() else: if self._success: return { self.expected_message.FULL_NAME: { k.upper(): v for k, v in self.matched_state.items() } } else: return {} def unmatched_events(self): if not self._deprecated_statedict: if not self._success: return EventContext( [ArsdkMessageEvent(self.expected_message, self.matched_state)] ) else: return EventContext() else: if not self._success: return { self.expected_message.FULL_NAME: { k.upper(): v for k, v in self.matched_state.items() } } else: return {} def __iter__(self): return iter((self,)) def __len__(self): return 1 def __repr__(self): return pprint.pformat({self.expected_message.FullName: self.expected_args}) class CheckWaitStateExpectationMixin: def __init__(self, check_expectation, wait_expectation): super().__init__() self._check_expectation = check_expectation self._wait_expectation = wait_expectation self._checked = False def _schedule(self, scheduler): super()._schedule(scheduler) self._check_expectation._schedule(scheduler) self._checked = self._check_expectation.success() self._success = self._checked if not self._success: scheduler._schedule( self._wait_expectation, monitor=self._wait_expectation.always_monitor ) else: self.set_success() def copy(self): other = super().base_copy( self._check_expectation.copy(), self._wait_expectation.copy() ) return other def check(self, *args, **kwds): if not self._checked and self._wait_expectation.check(*args, **kwds).success(): self.set_success() return self def expected_events(self): if self._checked: return EventContext( self._check_expectation.expected_events().events(), ExpectPolicy.check_wait, ) else: return EventContext( self._wait_expectation.expected_events().events(), ExpectPolicy.check_wait, ) def received_events(self): if self._checked: return self._check_expectation.received_events() else: return self._wait_expectation.received_events() def matched_events(self): if self._checked: return EventContext(self._check_expectation.matched_events().events()) else: return EventContext(self._wait_expectation.matched_events().events()) def unmatched_events(self): if self._checked: return EventContext(self._check_expectation.unmatched_events().events()) else: return EventContext(self._wait_expectation.unmatched_events().events()) def set_timeout(self, _timeout): super().set_timeout(_timeout) self._wait_expectation.set_timeout(_timeout) def timedout(self): if self._checked: return False else: if self._wait_expectation.timedout(): self.set_timedout() return self._wait_expectation.timedout() def cancelled(self): return self._wait_expectation.cancelled() class CheckWaitStateExpectation(CheckWaitStateExpectationMixin, Expectation): pass class ArsdkCheckWaitStateExpectation( CheckWaitStateExpectationMixin, ArsdkExpectationBase ): def _set_deprecated_statedict(self): super()._set_deprecated_statedict() if hasattr(self._check_expectation, "_set_deprecated_statedict"): self._check_expectation._set_deprecated_statedict() if hasattr(self._wait_expectation, "_set_deprecated_statedict"): self._wait_expectation._set_deprecated_statedict() def _fill_default_arguments(self, message, args): if hasattr(self._check_expectation, "_fill_default_arguments"): self._check_expectation._fill_default_arguments(message, args) if hasattr(self._wait_expectation, "_fill_default_arguments"): self._wait_expectation._fill_default_arguments(message, args) def set_float_tol(self, _float_tol): super().set_float_tol(_float_tol) self._check_expectation.set_float_tol(_float_tol) self._wait_expectation.set_float_tol(_float_tol) class MultipleExpectationMixin: def __init__(self, expectations=None): super().__init__() if expectations is None: self.expectations = [] else: self.expectations = expectations self.matched_expectations = IndexedSet() def copy(self): other = super().base_copy(list(map(lambda e: e.copy(), self.expectations))) return other def append(self, expectation): if not isinstance(expectation, self.__class__): self.expectations.append(expectation) else: self.expectations.extend(expectation.expectations) return self def expected_events(self): return MultipleEventContext( list(map(lambda e: e.expected_events(), self.expectations)), self._combine_method(), ) def received_events(self): return MultipleEventContext( list(map(lambda e: e.received_events(), self.expectations)), self._combine_method(), ) def matched_events(self): return MultipleEventContext( list(map(lambda e: e.matched_events(), self.matched_expectations)), self._combine_method(), ) def unmatched_events(self): return MultipleEventContext( list(map(lambda e: e.unmatched_events(), self.unmatched_expectations())), self._combine_method(), ) def unmatched_expectations(self): for expectation in self.expectations: if expectation not in self.matched_expectations: yield expectation def __iter__(self): return iter(self.expectations) def __len__(self): return len(self.expectations) def __repr__(self): return "<{}: {}>".format(self.__class__.__name__, repr(self.expectations)) @abstractmethod def _combine_method(self): pass def marked_events(self, default_marked_events=EventMarker.unmatched): if self._success: default_marked_events = EventMarker.ignored return MultipleEventContext( list( map(lambda e: e.marked_events(default_marked_events), self.expectations) ), self._combine_method(), ) def as_completed(self, timeout=None): end_time = None if timeout is not None: end_time = timeout + time.monotonic() done = set() while end_time is None or end_time > time.monotonic(): fs = OrderedDict([(e._future, e) for e in self.expectations if e not in done]) for f in as_completed(fs.keys(), timeout=timeout): yield fs[f] done.add(fs[f]) if len(done) == len(self.expectations): break class MultipleExpectation(MultipleExpectationMixin, Expectation): pass class ArsdkMultipleExpectation(MultipleExpectationMixin, ArsdkExpectationBase): def _set_deprecated_statedict(self): super()._set_deprecated_statedict() for expectation in self.expectations: if hasattr(expectation, "_set_deprecated_statedict"): expectation._set_deprecated_statedict() def _fill_default_arguments(self, message, args): for expectation in self.expectations: if hasattr(expectation, "_fill_default_arguments"): expectation._fill_default_arguments(message, args) def set_float_tol(self, _float_tol): for expectation in self.expectations: expectation.set_float_tol(_float_tol) @classmethod def from_arsdk(cls, messages, ar_expectations): expectations = list( map( lambda e: ArsdkEventExpectation.from_arsdk(messages, e), ar_expectations ) ) return cls(expectations) def expected_events(self): if not self._deprecated_statedict: return super().expected_events() else: return merge_mapping(map(lambda e: e.expected_events(), self.expectations)) def received_events(self): if not self._deprecated_statedict: return super().received_events() else: return merge_mapping(map(lambda e: e.received_events(), self.expectations)) def matched_events(self): if not self._deprecated_statedict: return super().matched_events() else: return merge_mapping( map(lambda e: e.matched_events(), self.matched_expectations) ) def unmatched_events(self): if not self._deprecated_statedict: return super().unmatched_events() else: return merge_mapping( map(lambda e: e.unmatched_events(), self.unmatched_expectations()) ) def marked_events(self, default_marked_events=EventMarker.unmatched): if not self._deprecated_statedict: return super().marked_events(default_marked_events=EventMarker.unmatched) else: if not self._success: return self.expected_events() else: return {} class WhenAnyExpectationMixin: def _schedule(self, scheduler): super()._schedule(scheduler) for expectation in self.expectations: scheduler._schedule(expectation, monitor=expectation.always_monitor) if expectation.success(): self.matched_expectations.add(expectation) self.set_success() break if self.success(): return if all(expectation.cancelled() for expectation in self.expectations): self.cancel() def timedout(self): if super().timedout(): return True elif all(map(lambda e: e.timedout(), self.expectations)): self.set_timedout() return super().timedout() def cancelled(self): if super().cancelled(): return True elif all(map(lambda e: e.cancelled(), self.expectations)): self.cancel() return True else: return False def check(self, *args, **kwds): for expectation in self.expectations: if ( expectation.always_monitor or not expectation.success() ) and expectation.check(*args, **kwds).success(): self.matched_expectations.add(expectation) self.set_success() return self return self def __or__(self, other): return self.append(other) def _combine_method(self): return "|" class WhenAnyExpectation(WhenAnyExpectationMixin, MultipleExpectation): pass class ArsdkWhenAnyExpectation(WhenAnyExpectationMixin, ArsdkMultipleExpectation): pass class WhenAllExpectationsMixin: def _schedule(self, scheduler): super()._schedule(scheduler) for expectation in self.expectations: scheduler._schedule(expectation, monitor=expectation.always_monitor) if expectation.success(): self.matched_expectations.add(expectation) if len(self.expectations) == len(self.matched_expectations): self.set_success() elif any(expectation.cancelled() for expectation in self.expectations): self.cancel() def timedout(self): if super().timedout(): return True elif any(map(lambda e: e.timedout(), self.expectations)): self.set_timedout() return super().timedout() def cancelled(self): if super().cancelled(): return True elif any(map(lambda e: e.cancelled(), self.expectations)): self.cancel() return True else: return False def check(self, *args, **kwds): for expectation in self.expectations: if ( expectation.always_monitor or not expectation.success() ) and expectation.check(*args, **kwds).success(): self.matched_expectations.add(expectation) if len(self.expectations) == len(self.matched_expectations): self.set_success() return self def __and__(self, other): return self.append(other) def _combine_method(self): return "&" class WhenAllExpectations(WhenAllExpectationsMixin, MultipleExpectation): pass class ArsdkWhenAllExpectations(WhenAllExpectationsMixin, ArsdkMultipleExpectation): pass class ArsdkCommandExpectation(ArsdkMultipleExpectation): def __init__(self, command_message, command_args=None, expectations=None): super().__init__(expectations) self.command_message = command_message.new() self.command_args = command_args or [] self._command_future = None self._no_expect = False def timedout(self): if super().timedout(): return True elif any(map(lambda e: e.timedout(), self.expectations)): self.set_timedout() return super().timedout() def cancelled(self): if super().cancelled(): return True elif any(map(lambda e: e.cancelled(), self.expectations)): self.cancel() return True else: return False def check(self, received_event, *args, **kwds): if not isinstance(received_event, ArsdkMessageEvent): return self if self._command_future is None or ( not self._command_future.done() or not self._command_future.result() ): return self if self._no_expect: self.set_success() return self for expectation in self.expectations: if ( expectation.always_monitor or not expectation.success() ) and expectation.check(received_event).success(): self.matched_expectations.add(expectation) if len(self.expectations) == len(self.matched_expectations): self.set_success() return self def _fill_default_arguments(self, message, args): super()._fill_default_arguments(message, args) if self.command_message.id != message.id: raise RuntimeError( "Unexpected message {} where {} was expected".format( message.fullName, self.command_message.fullName ) ) self.command_args = list(args.values()) def copy(self): return super().base_copy( self.command_message.copy(), self.command_args[:], list(map(lambda e: e.copy(), self.expectations)), ) def _schedule(self, scheduler): if not self._awaited: for expectation in self.expectations: scheduler._schedule(expectation, monitor=expectation.always_monitor) controller = scheduler.context("olympe.controller") self._command_future = controller._send_command_raw( self.command_message, *self.command_args ) super()._schedule(scheduler) def no_expect(self, value): self._no_expect = value def _combine_method(self): return "&" def explain(self): if self._command_future is None: return "{} has not been sent yet".format(self.command_message.fullName) elif not self._command_future.done() or not self._command_future.result(): return "{} has been sent but hasn't been acknowledged".format( self.command_message.fullName ) else: ret = "{} has been sent and acknowledged.".format( self.command_message.fullName ) if not self._no_expect and self.expectations: ret += " Command expectations status :\n{}".format(super().explain()) return ret class WhenSequenceExpectationsMixin: def _schedule(self, scheduler): super()._schedule(scheduler) self._do_schedule() def _do_schedule(self): if self._scheduler is None: return # Schedule all available expectations in this sequence until we # encounter a pending asynchronous expectation while self._current_expectation() is not None: self._scheduler._schedule( self._current_expectation(), monitor=self._current_expectation().always_monitor, ) if not self._current_expectation().success(): break self.matched_expectations.add(self._current_expectation()) if len(self.expectations) == len(self.matched_expectations): self.set_success() elif any(expectation.cancelled() for expectation in self.expectations): self.cancel() def timedout(self): if super().timedout(): return True elif any(map(lambda e: e.timedout(), self._pending_expectations())): self.set_timedout() return super().timedout() def cancelled(self): if super().cancelled(): return True elif any(map(lambda e: e.cancelled(), self._pending_expectations())): self.cancel() return True else: return False def _current_expectation(self): return ( self.expectations[len(self.matched_expectations)] if len(self.matched_expectations) < len(self.expectations) else None ) def _pending_expectations(self): return ( self.expectations[len(self.matched_expectations):] if len(self.matched_expectations) < len(self.expectations) else [] ) def check(self, *args, **kwds): if self._current_expectation() is None: self.set_success() return self # While the current event matches an unmatched expectation # in this sequence while ( self._current_expectation() is not None and ( self._current_expectation().always_monitor or not self._current_expectation().success() ) and self._current_expectation().check(*args, **kwds).success() ): # Consume the current expectation self.matched_expectations.add(self._current_expectation()) # Schedule the next expectation(s), if any. # This may also consume one or more synchronous expectations # (i.e. events with policy="check"). self._do_schedule() if len(self.expectations) == len(self.matched_expectations): self.set_success() return self def __rshift__(self, other): return self.append(other) def _combine_method(self): return ">>" class WhenSequenceExpectations(WhenSequenceExpectationsMixin, MultipleExpectation): pass class ArsdkWhenSequenceExpectations( WhenSequenceExpectationsMixin, ArsdkMultipleExpectation ): pass