from __future__ import unicode_literals from six import iteritems, add_metaclass from six.moves import map from .exceptions import UnknownDslObject, ValidationException SKIP_VALUES = ('', None) def _wrap(val, obj_wrapper=None): if isinstance(val, dict): return AttrDict(val) if obj_wrapper is None else obj_wrapper(val) if isinstance(val, list): return AttrList(val) return val def _make_dsl_class(base, name, params_def=None, suffix=''): """ Generate a DSL class based on the name of the DSL object and it's parameters """ attrs = {'name': name} if params_def: attrs['_param_defs'] = params_def cls_name = str(''.join(s.title() for s in name.split('_')) + suffix) return type(cls_name, (base, ), attrs) class AttrList(object): def __init__(self, l, obj_wrapper=None): # make iteables into lists if not isinstance(l, list): l = list(l) self._l_ = l self._obj_wrapper = obj_wrapper def __repr__(self): return repr(self._l_) def __eq__(self, other): if isinstance(other, AttrList): return other._l_ == self._l_ # make sure we still equal to a dict with the same data return other == self._l_ def __getitem__(self, k): l = self._l_[k] if isinstance(k, slice): return AttrList(l) return _wrap(l, self._obj_wrapper) def __setitem__(self, k, value): self._l_[k] = value def __iter__(self): return map(lambda i: _wrap(i, self._obj_wrapper), self._l_) def __len__(self): return len(self._l_) def __nonzero__(self): return bool(self._l_) __bool__ = __nonzero__ def __getattr__(self, name): return getattr(self._l_, name) class AttrDict(object): """ Helper class to provide attribute like access (read and write) to dictionaries. Used to provide a convenient way to access both results and nested dsl dicts. """ def __init__(self, d): # assign the inner dict manually to prevent __setattr__ from firing super(AttrDict, self).__setattr__('_d_', d) def __contains__(self, key): return key in self._d_ def __nonzero__(self): return bool(self._d_) __bool__ = __nonzero__ def __dir__(self): # introspection for auto-complete in IPython etc return list(self._d_.keys()) def __eq__(self, other): if isinstance(other, AttrDict): return other._d_ == self._d_ # make sure we still equal to a dict with the same data return other == self._d_ def __repr__(self): r = repr(self._d_) if len(r) > 60: r = r[:60] + '...}' return r def __getattr__(self, attr_name): try: return _wrap(self._d_[attr_name]) except KeyError: raise AttributeError( '%r object has no attribute %r' % (self.__class__.__name__, attr_name)) def __delattr__(self, attr_name): try: del self._d_[attr_name] except KeyError: raise AttributeError( '%r object has no attribute %r' % (self.__class__.__name__, attr_name)) def __getitem__(self, key): return _wrap(self._d_[key]) def __setitem__(self, key, value): self._d_[key] = value def __delitem__(self, key): del self._d_[key] def __setattr__(self, name, value): if name in self._d_ or not hasattr(self.__class__, name): self._d_[name] = value else: # there is an attribute on the class (could be property, ..) - don't add it as field super(AttrDict, self).__setattr__(name, value) def __iter__(self): return iter(self._d_) def to_dict(self): return self._d_ class DslMeta(type): """ Base Metaclass for DslBase subclasses that builds a registry of all classes for given DslBase subclass (== all the query types for the Query subclass of DslBase). It then uses the information from that registry (as well as `name` and `shortcut` attributes from the base class) to construct any subclass based on it's name. For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`. """ _types = {} def __init__(cls, name, bases, attrs): super(DslMeta, cls).__init__(name, bases, attrs) # skip for DslBase if not hasattr(cls, '_type_shortcut'): return if cls.name is None: # abstract base class, register it's shortcut cls._types[cls._type_name] = cls._type_shortcut # and create a registry for subclasses if not hasattr(cls, '_classes'): cls._classes = {} elif cls.name not in cls._classes: # normal class, register it cls._classes[cls.name] = cls @classmethod def get_dsl_type(cls, name): try: return cls._types[name] except KeyError: raise UnknownDslObject('DSL type %s does not exist.' % name) @add_metaclass(DslMeta) class DslBase(object): """ Base class for all DSL objects - queries, filters, aggregations etc. Wraps a dictionary representing the object's json. Provides several feature: - attribute access to the wrapped dictionary (.field instead of ['field']) - _clone method returning a deep copy of self - to_dict method to serialize into dict (to be sent via elasticsearch-py) - basic logical operators (&, | and ~) using a Bool(Filter|Query) TODO: move into a class specific for Query/Filter - respects the definiton of the class and (de)serializes it's attributes based on the `_param_defs` definition (for example turning all values in the `must` attribute into Query objects) """ _param_defs = {} @classmethod def get_dsl_class(cls, name): try: return cls._classes[name] except KeyError: raise UnknownDslObject('DSL class `%s` does not exist in %s.' % (name, cls._type_name)) def __init__(self, **params): self._params = {} for pname, pvalue in iteritems(params): if '__' in pname: pname = pname.replace('__', '.') self._setattr(pname, pvalue) def _repr_params(self): """ Produce a repr of all our parameters to be used in __repr__. """ return ', '.join( '%s=%r' % (n.replace('.', '__'), v) for (n, v) in sorted(iteritems(self._params)) # make sure we don't include empty typed params if 'type' not in self._param_defs.get(n, {}) or v ) def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, self._repr_params() ) def __eq__(self, other): return isinstance(other, self.__class__) and other.to_dict() == self.to_dict() def __ne__(self, other): return not self == other def __setattr__(self, name, value): if name.startswith('_'): return super(DslBase, self).__setattr__(name, value) return self._setattr(name, value) def _setattr(self, name, value): # if this attribute has special type assigned to it... if name in self._param_defs: pinfo = self._param_defs[name] if 'type' in pinfo: # get the shortcut used to construct this type (query.Q, aggs.A, etc) shortcut = self.__class__.get_dsl_type(pinfo['type']) if pinfo.get('multi'): value = list(map(shortcut, value)) # dict(name -> DslBase), make sure we pickup all the objs elif pinfo.get('hash'): value = dict((k, shortcut(v)) for (k, v) in iteritems(value)) # single value object, just convert else: value = shortcut(value) self._params[name] = value def __getattr__(self, name): if name.startswith('_'): raise AttributeError( '%r object has no attribute %r' % (self.__class__.__name__, name)) value = None try: value = self._params[name] except KeyError: # compound types should never throw AttributeError and return empty # container instead if name in self._param_defs: pinfo = self._param_defs[name] if pinfo.get('multi'): value = self._params.setdefault(name, []) elif pinfo.get('hash'): value = self._params.setdefault(name, {}) if value is None: raise AttributeError( '%r object has no attribute %r' % (self.__class__.__name__, name)) # wrap nested dicts in AttrDict for convenient access if isinstance(value, dict): return AttrDict(value) return value def to_dict(self): """ Serialize the DSL object to plain dict """ d = {} for pname, value in iteritems(self._params): pinfo = self._param_defs.get(pname) # typed param if pinfo and 'type' in pinfo: # don't serialize empty lists and dicts for typed fields if value in ({}, []): continue # multi-values are serialized as list of dicts if pinfo.get('multi'): value = list(map(lambda x: x.to_dict(), value)) # squash all the hash values into one dict elif pinfo.get('hash'): value = dict((k, v.to_dict()) for k, v in iteritems(value)) # serialize single values else: value = value.to_dict() # serialize anything with to_dict method elif hasattr(value, 'to_dict'): value = value.to_dict() d[pname] = value return {self.name: d} def _clone(self): return self._type_shortcut(self.to_dict()) def __add__(self, other): # make sure we give queries that know how to combine themselves # preference if hasattr(other, '__radd__'): return other.__radd__(self) return self._bool(must=[self, other]) def __invert__(self): return self._bool(must_not=[self]) def __or__(self, other): # make sure we give queries that know how to combine themselves # preference if hasattr(other, '__ror__'): return other.__ror__(self) return self._bool(should=[self, other]) def __and__(self, other): # make sure we give queries that know how to combine themselves # preference if hasattr(other, '__rand__'): return other.__rand__(self) return self._bool(must=[self, other]) class BoolMixin(object): """ Mixin containing all the operator overrides for Bool queries and filters. Except for and where should behavior differs """ def __add__(self, other): q = self._clone() if isinstance(other, self.__class__): q.must += other.must q.should += other.should q.must_not += other.must_not else: q.must.append(other) return q __radd__ = __add__ def __or__(self, other): if not (self.must or self.must_not): # TODO: if only 1 in must or should, append the query instead of other q = self._clone() q.should.append(other) return q elif isinstance(other, self.__class__) and not (other.must or other.must_not): # TODO: if only 1 in must or should, append the query instead of self q = other._clone() q.should.append(self) return q return self.__class__(should=[self, other]) __ror__ = __or__ def __invert__(self): # special case for single negated query if not (self.must or self.should) and len(self.must_not) == 1: return self.must_not[0]._clone() # bol without should, just flip must and must_not elif not self.should: q = self._clone() q.must, q.must_not = q.must_not, q.must return q # TODO: should -> must_not.append(self.__class__(should=self.should)) ?? # queries with should just invert normally return super(BoolMixin, self).__invert__() class ObjectBase(AttrDict): def __init__(self, **kwargs): m = self._doc_type.mapping for k in m: if k in kwargs and m[k]._coerce: kwargs[k] = m[k].to_python(kwargs[k]) super(ObjectBase, self).__init__(kwargs) def __getattr__(self, name): try: return super(ObjectBase, self).__getattr__(name) except AttributeError: if name in self._doc_type.mapping: f = self._doc_type.mapping[name] if hasattr(f, 'empty'): value = f.empty() if value not in SKIP_VALUES: setattr(self, name, value) value = getattr(self, name) return value raise def __setattr__(self, name, value): if name in self._doc_type.mapping: value = self._doc_type.mapping[name].to_python(value) super(ObjectBase, self).__setattr__(name, value) def to_dict(self): out = {} for k, v in iteritems(self._d_): if isinstance(v, (AttrList, list, tuple)): v = [i.to_dict() if hasattr(i, 'to_dict') else i for i in v] else: v = v.to_dict() if hasattr(v, 'to_dict') else v # don't serialize empty values # careful not to include numeric zeros if v in ([], {}, None): continue out[k] = v return out def clean_fields(self): errors = {} for name in self._doc_type.mapping: field = self._doc_type.mapping[name] data = self._d_.get(name, None) try: # save the cleaned value self._d_[name] = field.clean(data) except ValidationException as e: errors.setdefault(name, []).append(e) if errors: raise ValidationException(errors) def clean(self): pass def full_clean(self): self.clean_fields() self.clean() def merge(data, new_data): if not (isinstance(data, (AttrDict, dict)) and isinstance(new_data, (AttrDict, dict))): raise ValueError('You can only merge two dicts! Got %r and %r instead.' % (data, new_data)) for key, value in iteritems(new_data): if key in data and isinstance(data[key], (AttrDict, dict)): merge(data[key], value) else: data[key] = value