# Copyright 2016 Quora, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__doc__ = """

Various small helper classes and routines.

"""

import inspect
import six
import warnings

from . import inspectable_class

# export it here for backward compatibility
from .disallow_inheritance import DisallowInheritance

# Shared empty containers. Note that empty_list and empty_dict are mutable
# objects; every user of these names sees the same instance.
empty_tuple = ()
empty_list = []
empty_dict = {}
# NOTE(review): the explicit globals() writes duplicate the assignments above;
# presumably they keep the names visible when the module is compiled with
# Cython -- confirm before removing.
globals()["empty_tuple"] = empty_tuple
globals()["empty_list"] = empty_list
globals()["empty_dict"] = empty_dict


def true_fn():
    """Returns True."""
    return True


def false_fn():
    """Returns False."""
    return False


class MarkerObject(object):
    """Replaces None in cases when None value is also expected.

    Used mainly by caches to describe a cache miss.

    The marker's name is what str() and repr() return.

    """

    def __init__(self, name):
        """Creates a marker with the given name.

        :param name: Text name of the marker. On Python 2 a bytes name is
                     accepted with a py3k warning and decoded as UTF-8.
        :raises TypeError: on Python 3 if name is bytes.

        """
        if isinstance(name, six.binary_type):
            if six.PY2:
                warnings.warnpy3k(
                    "MarkerObject does not support bytes names in Python 3"
                )
                name = name.decode("utf-8")
            else:
                raise TypeError("name must be str, not bytes")
        self.name = name

    if six.PY2:

        # On Python 2 str() must return bytes, so encode the text name.
        def __str__(self):
            return unicode(self).encode("utf-8")

        def __unicode__(self):
            return self.name

    else:

        def __str__(self):
            return self.name

    def __repr__(self):
        return self.name


none = MarkerObject(u"none")
miss = MarkerObject(u"miss")
same = MarkerObject(u"same")
unspecified = MarkerObject(u"unspecified")
globals()["none"] = none
globals()["miss"] = miss
globals()["same"] = same
globals()["unspecified"] = unspecified


class EmptyContext(object):
    """Context manager that does nothing on enter and on exit.

    __exit__ returns None, so exceptions raised inside the with block
    propagate unchanged.

    """

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def __repr__(self):
        return "qcore.empty_context"


empty_context = EmptyContext()
globals()["empty_context"] = empty_context


class CythonCachedHashWrapper(object):
    """Wraps a value and caches its hash, computed once at construction.

    Calling the wrapper (or calling value()) returns the wrapped value;
    hash(wrapper) (or calling hash()) returns the cached hash.

    """

    def __init__(self, value):
        self._value = value
        self._hash = hash(value)

    def value(self):
        """Returns the wrapped value."""
        return self._value

    def hash(self):
        """Returns the cached hash of the wrapped value."""
        return self._hash

    def __call__(self):
        """Same as value()."""
        return self._value

    def __hash__(self):
        return self._hash

    def __richcmp__(self, other, op):
        # Cython way of implementing comparison operations
        # op == 2 is handled as equality and op == 3 as its negation; a
        # wrapped operand is unwrapped before comparing.
        if op == 2:
            return (
                self() == other()
                if isinstance(other, CachedHashWrapper)
                else self() == other
            )
        elif op == 3:
            return not (
                self() == other()
                if isinstance(other, CachedHashWrapper)
                else self() == other
            )
        else:
            raise NotImplementedError("only == and != are supported")

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._value)


CachedHashWrapper = CythonCachedHashWrapper
globals()["CachedHashWrapper"] = CythonCachedHashWrapper

if hasattr(CythonCachedHashWrapper, "__richcmp__"):
    # This isn't Cython, so we must add eq and ne to make it work w/o Cython
    class PythonCachedHashWrapper(CachedHashWrapper):
        """Pure-Python variant that implements == and != directly."""

        def __eq__(self, other):
            return (
                self._value == other._value
                if isinstance(other, CachedHashWrapper)
                else self._value == other
            )

        def __ne__(self, other):
            return not (
                self._value == other._value
                if isinstance(other, CachedHashWrapper)
                else self._value == other
            )

        # needed in Python 3 because this class overrides __eq__
        def __hash__(self):
            return self._hash

    CachedHashWrapper = PythonCachedHashWrapper
    globals()["CachedHashWrapper"] = PythonCachedHashWrapper


class ScopedValue(object):
    """Holds a single value that can be temporarily overridden.

    Use get(), or call the instance, to read the value; set() to replace it;
    and override() in a with statement to replace it for the duration of the
    block.

    """

    def __init__(self, default):
        self._value = default

    def get(self):
        """Returns the current value."""
        return self._value

    def set(self, value):
        """Replaces the current value."""
        self._value = value

    def override(self, value):
        """Temporarily overrides the old value with the new one.

        Returns a context manager. When value is already the current value
        (by identity), the no-op empty_context is returned instead.

        """
        if self._value is not value:
            return _ScopedValueOverrideContext(self, value)
        else:
            return empty_context

    def __call__(self):
        """Same as get."""
        return self._value

    def __str__(self):
        return "ScopedValue(%s)" % (self._value,)

    def __repr__(self):
        return "ScopedValue(%r)" % (self._value,)


class _ScopedValueOverrideContext(object):
    """Context manager that swaps a ScopedValue's value for a with block."""

    def __init__(self, target, value):
        self._target = target
        self._value = value
        self._old_value = None

    def __enter__(self):
        # Remember the value at entry time so __exit__ can put it back.
        self._old_value = self._target._value
        self._target._value = self._value

    def __exit__(self, exc_type, exc_value, tb):
        self._target._value = self._old_value


class _PropertyOverrideContext(object):
    """Context manager that swaps an attribute of an object for a with block.

    On enter the current attribute value is read with getattr and replaced
    with the new value; on exit the saved value is written back with setattr.

    """

    def __init__(self, target, property_name, value):
        self._target = target
        self._property_name = property_name
        self._value = value
        self._old_value = None

    def __enter__(self):
        self._old_value = getattr(self._target, self._property_name)
        setattr(self._target, self._property_name, self._value)

    def __exit__(self, exc_type, exc_value, tb):
        setattr(self._target, self._property_name, self._old_value)


override = _PropertyOverrideContext
globals()["override"] = override


def ellipsis(source, max_length):
    """Truncates a string to be at most max_length long.

    :param source: The string to truncate.
    :param max_length: Maximum length of the result. 0 disables truncation.
    :return: source itself if max_length is 0 or source already fits;
             otherwise a prefix of source followed by "...", max_length
             characters long in total. When max_length is 1 or 2 there is no
             room for the whole marker, so the marker itself is clipped to
             max_length characters. A negative max_length yields "...".

    """
    if max_length == 0 or len(source) <= max_length:
        return source
    if 0 < max_length < 3:
        # A full "..." would exceed the requested limit; clip the marker so
        # the result never gets longer than max_length.
        return "..."[:max_length]
    return source[: max(0, max_length - 3)] + "..."

def safe_str(source, max_length=0):
    """Wrapper for str() that catches exceptions.

    :param source: Object to convert.
    :param max_length: Passed to ellipsis() to limit the length of the
                       result; the default of 0 means no truncation.
    :return: str(source), or a "<n/a: ...>" placeholder describing the
             exception if str() raised one.

    """
    try:
        return ellipsis(str(source), max_length)
    except Exception as e:
        return ellipsis("<n/a: str(...) raised %s>" % e, max_length)


def safe_repr(source, max_length=0):
    """Wrapper for repr() that catches exceptions.

    :param source: Object to convert.
    :param max_length: Passed to ellipsis() to limit the length of the
                       result; the default of 0 means no truncation.
    :return: repr(source), or a "<n/a: ...>" placeholder describing the
             exception if repr() raised one.

    """
    try:
        return ellipsis(repr(source), max_length)
    except Exception as e:
        return ellipsis("<n/a: repr(...) raised %s>" % e, max_length)


def dict_to_object(source):
    """Returns an object with the key-value pairs in source as attributes."""
    target = inspectable_class.InspectableClass()
    for k, v in source.items():
        setattr(target, k, v)
    return target


def copy_public_attrs(source_obj, dest_obj):
    """Shallow copies all public attributes from source_obj to dest_obj.

    Overwrites them if they already exist.

    Members whose names start with "_", "func" or "im" are skipped.

    """
    skipped_prefixes = ("_", "func", "im")
    for name, value in inspect.getmembers(source_obj):
        if not name.startswith(skipped_prefixes):
            setattr(dest_obj, name, value)


def object_from_string(name):
    """Creates a Python class or function from its fully qualified name.

    :param name: A fully qualified name of a class or a function. In Python 3 this
        is only allowed to be of text type (unicode). In Python 2, both bytes and unicode
        are allowed.
    :return: A function or class object.
    :raises TypeError: if name is not of an accepted string type.
    :raises ValueError: if name contains no "." separator.

    This method is used by serialization code to create a function or class
    from a fully qualified name.

    """
    if six.PY3:
        if not isinstance(name, str):
            raise TypeError("name must be str, not %r" % type(name))
    else:
        if isinstance(name, unicode):
            name = name.encode("ascii")
        if not isinstance(name, (str, unicode)):
            raise TypeError("name must be bytes or unicode, got %r" % type(name))

    pos = name.rfind(".")
    if pos < 0:
        raise ValueError("Invalid function or class name %s" % name)
    module_name = name[:pos]
    func_name = name[pos + 1 :]
    try:
        mod = __import__(module_name, fromlist=[func_name], level=0)
    except ImportError:
        # Hail mary. if the from import doesn't work, then just import the top level module
        # and do getattr on it, one level at a time. This will handle cases where imports are
        # done like `from . import submodule as another_name`
        parts = name.split(".")
        mod = __import__(parts[0], level=0)
        for part in parts[1:]:
            mod = getattr(mod, part)
        return mod
    else:
        return getattr(mod, func_name)


def catchable_exceptions(exceptions):
    """Returns True if exceptions can be caught in the except clause.

    The exception can be caught if it is an Exception type or a tuple of
    exception types.

    Anything else -- including an empty tuple, or a tuple with an element
    that is not a BaseException subclass -- gives False.

    """
    if isinstance(exceptions, type) and issubclass(exceptions, BaseException):
        return True

    if (
        isinstance(exceptions, tuple)
        and exceptions
        # issubclass() raises TypeError for non-class arguments, so check that
        # each element is a class first, as the single-value branch does.
        and all(
            isinstance(it, type) and issubclass(it, BaseException)
            for it in exceptions
        )
    ):
        return True

    return False