# Copyright 2015 Palo Alto Networks, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import netaddr
import uuid
import shutil

from . import base
from . import actorbase
from . import table
from . import st
from .utils import utc_millisec
from .utils import RESERVED_ATTRIBUTES

LOG = logging.getLogger(__name__)

# Level assigned to segments coming from whitelist sources. Ranges whose
# highest covering level is WL_LEVEL are never emitted by _calc_ipranges.
WL_LEVEL = st.MAX_LEVEL

# Text types accepted by AggregateIPv4FT.get(). ``unicode`` only exists on
# Python 2; on Python 3 ``str`` is the only text type.
try:
    _STRING_TYPES = (str, unicode)  # noqa: F821
except NameError:
    _STRING_TYPES = (str,)


class MWUpdate(object):
    """An aggregated IP range together with the segment ids covering it.

    Equality and hashing depend only on the range boundaries, NOT on
    ``uuids``. Callers rely on this to detect a range that is unchanged in
    extent but whose set of contributing ids has changed.

    Args:
        start: first address of the range (anything netaddr.IPAddress
            accepts; callers in this module pass ints).
        end: last address of the range, inclusive.
        uuids: iterable of segment ids covering the range.
    """

    def __init__(self, start, end, uuids):
        self.start = start
        self.end = end
        self.uuids = set(uuids)

        s = netaddr.IPAddress(start)
        e = netaddr.IPAddress(end)
        self._indicator = '%s-%s' % (s, e)

    def indicator(self):
        """Return the range formatted as ``<start>-<end>``."""
        return self._indicator

    def __repr__(self):
        return 'MWUpdate(%s, %r)' % (self._indicator, self.uuids)

    def __hash__(self):
        return hash(self._indicator)

    def __eq__(self, other):
        # Comparing against a foreign object must not raise AttributeError.
        if not isinstance(other, MWUpdate):
            return NotImplemented
        return self.start == other.start and self.end == other.end

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class AggregateIPv4FT(actorbase.ActorBaseFT):
    """Aggregate IPv4 indicators from multiple sources into disjoint ranges.

    Every (indicator, source) pair is stored in ``self.table`` and
    registered as a segment in ``self.st``. Updates and withdraws are
    translated into updates/withdraws of the resulting non-overlapping
    ranges, each carrying the merged attributes of the segments covering it.

    Config keys read by ``configure``:
        whitelist_prefixes (list): source-name prefixes treated as
            whitelists (segments stored with level ``WL_LEVEL``).
        enable_list_merge (bool): when true, list attributes of different
            sources are concatenated instead of overwritten.
    """

    def __init__(self, name, chassis, config):
        # Set before calling the base constructor, as in the original code.
        # NOTE(review): presumably the base __init__ may trigger code that
        # reads this attribute - confirm against ActorBaseFT.
        self.active_requests = []

        super(AggregateIPv4FT, self).__init__(name, chassis, config)

    def configure(self):
        """Load node options from ``self.config``."""
        super(AggregateIPv4FT, self).configure()

        self.whitelist_prefixes = self.config.get('whitelist_prefixes', [])
        self.enable_list_merge = self.config.get('enable_list_merge', False)

    def _initialize_tables(self, truncate=False):
        """Open the indicator table and the segment structure.

        Args:
            truncate (bool): forwarded to both stores.
        """
        self.table = table.Table(
            self.name,
            bloom_filter_bits=10,
            truncate=truncate
        )
        # _calc_indicator_value looks values up by their '_id' attribute.
        self.table.create_index('_id')
        # NOTE(review): 32 is presumably the address width in bits - confirm
        # against st.ST.
        self.st = st.ST(self.name+'_st', 32, truncate=truncate)

    def initialize(self):
        """Open the stores, keeping existing content."""
        self._initialize_tables()

    def rebuild(self):
        """Open the stores with truncate=True."""
        self._initialize_tables(truncate=True)

    def reset(self):
        """Open the stores with truncate=True."""
        self._initialize_tables(truncate=True)

    def _indicator_key(self, indicator, source):
        """Return the table key for an (indicator, source) pair."""
        return indicator+'\x00'+source

    def _calc_indicator_value(self, uuids, additional_uuid=None,
                              additional_value=None):
        """Merge the stored values of a set of segments into one value.

        Args:
            uuids: iterable of segment ids (the '_id' of table entries).
            additional_uuid: id whose value is not in the table (e.g. it
                was just deleted); ``additional_value`` is used for it.
            additional_value (dict): value to use for ``additional_uuid``.

        Returns:
            dict: merged attributes. Attributes listed in
            RESERVED_ATTRIBUTES are combined with their merge function;
            list attributes are concatenated when ``enable_list_merge`` is
            set; for everything else the last value seen wins.
        """
        mv = {'sources': []}

        for uuid_ in uuids:
            if uuid_ == additional_uuid:
                v = additional_value
            else:
                k, v = next(
                    self.table.query('_id', from_key=uuid_, to_key=uuid_,
                                     include_value=True),
                    (None, None)
                )
                if k is None:
                    LOG.error(
                        "Unable to find key associated with uuid: %s", uuid_
                    )
                    # Nothing to merge for this id; iterating the missing
                    # value would raise TypeError.
                    continue

            for vk in v:
                if vk in mv and vk in RESERVED_ATTRIBUTES:
                    mv[vk] = RESERVED_ATTRIBUTES[vk](mv[vk], v[vk])
                    continue

                merge_lists = (
                    self.enable_list_merge and
                    vk in mv and
                    isinstance(mv[vk], list) and
                    isinstance(v[vk], list)
                )
                if merge_lists:
                    # Build a new list: mv[vk] may alias a list owned by a
                    # source value and must not be mutated in place.
                    mv[vk] = mv[vk] + v[vk]
                else:
                    mv[vk] = v[vk]

        return mv

    def _merge_values(self, origin, ov, nv):
        """Return ``nv`` layered over the bookkeeping fields of ``ov``.

        Args:
            origin: unused, kept for interface compatibility.
            ov (dict): stored value; must contain '_added' and '_id'.
            nv (dict): new attributes; they override everything else.

        Returns:
            dict: new dict, inputs are not modified.
        """
        result = {
            'sources': [],
            '_added': ov['_added'],
            '_id': ov['_id']
        }
        result.update(nv)

        return result

    def _add_indicator(self, origin, indicator, value):
        """Create or refresh the table entry of an (indicator, origin) pair.

        Returns:
            tuple: (stored value, True if the entry did not exist before).
        """
        added = False

        now = utc_millisec()
        ik = self._indicator_key(indicator, origin)

        v = self.table.get(ik)
        if v is None:
            v = {
                '_id': str(uuid.uuid4()),
                '_added': now
            }
            added = True
            self.statistics['added'] += 1

        v = self._merge_values(origin, v, value)
        v['_updated'] = now

        self.table.put(ik, v)

        return v, added

    def _calc_ipranges(self, start, end):
        """Calc IP Ranges overlapping the range between start and end

        Args:
            start (int): start of the range
            end (int): end of the range

        Returns:
            set: set of ranges
        """
        result = set()

        # collect the endpoints between start and end
        eps = set()
        for epaddr, _, _, _ in self.st.query_endpoints(start=start, stop=end):
            eps.add(epaddr)
        eps = sorted(eps)

        if len(eps) == 0:
            return result

        # walk thru the endpoints, tracking last endpoint,
        # current level, active segments and segments levels
        oep = None
        oeplevel = -1
        live_ids = set()
        slevels = {}

        for epaddr in eps:
            # for each endpoint we track which segments are starting
            # and which ones are ending with that specific endpoint
            end_ids = set()
            start_ids = set()
            eplevel = 0
            for cuuid, clevel, cstart, cend in self.st.cover(epaddr):
                slevels[cuuid] = clevel

                if clevel > eplevel:
                    eplevel = clevel

                if cstart == epaddr:
                    start_ids.add(cuuid)
                if cend == epaddr:
                    end_ids.add(cuuid)

                if cend != epaddr and cstart != epaddr:
                    # segment crossing this endpoint: it can be unknown
                    # only if it started before the first endpoint
                    if cuuid not in live_ids:
                        assert epaddr == eps[0]
                        live_ids.add(cuuid)

            assert len(end_ids) + len(start_ids) > 0

            if len(start_ids) != 0:
                # close the range running up to just before this endpoint,
                # unless it was covered by a whitelist segment
                if oep is not None and oep != epaddr and len(live_ids) != 0:
                    if oeplevel != WL_LEVEL:
                        result.add(MWUpdate(oep, epaddr-1, live_ids))

                oep = epaddr
                oeplevel = eplevel
                live_ids = live_ids | start_ids

            if len(end_ids) != 0:
                # close the range ending exactly on this endpoint,
                # unless a whitelist segment covers the endpoint
                if oep is not None and len(live_ids) != 0:
                    if eplevel < WL_LEVEL:
                        result.add(MWUpdate(oep, epaddr, live_ids))

                oep = epaddr+1
                live_ids = live_ids - end_ids

                # level of what is left after the ending segments are gone
                oeplevel = eplevel
                if len(live_ids) != 0:
                    oeplevel = max([slevels[id_] for id_ in live_ids])

        return result

    def _range_from_indicator(self, indicator):
        """Convert an indicator into an inclusive integer range.

        Accepted forms: ``a.b.c.d``, ``a.b.c.d/nn`` and ``start-end``.

        Returns:
            tuple: (start, end) as ints, or (None, None) when the indicator
            cannot be parsed or falls outside the IPv4 address space. The
            failure is logged.
        """
        try:
            if '-' in indicator:
                start, end = map(
                    lambda x: int(netaddr.IPAddress(x)),
                    indicator.split('-', 1)
                )
            elif '/' in indicator:
                ipnet = netaddr.IPNetwork(indicator)
                start = int(ipnet.ip)
                end = start+ipnet.size-1
            else:
                start = int(netaddr.IPAddress(indicator))
                end = start
        except (netaddr.AddrFormatError, ValueError):
            # Malformed text is reported the same way as an out of range
            # address instead of propagating out of the update path.
            LOG.error('%s - {%s} invalid IPv4 indicator',
                      self.name, indicator)
            return None, None

        # Also rejects IPv6 addresses, whose integer value is too large.
        # NOTE(review): start > end is not rejected here - confirm how
        # st.put treats an inverted range.
        if not (0 <= start <= 0xFFFFFFFF and 0 <= end <= 0xFFFFFFFF):
            LOG.error('%s - {%s} invalid IPv4 indicator',
                      self.name, indicator)
            return None, None

        return start, end

    def _endpoints_from_range(self, start, end):
        """Return last endpoint before range and first endpoint after range

        Args:
            start (int): range start
            end (int): range stop

        Returns:
            tuple: (last endpoint before, first endpoint after); an element
            is None when the corresponding query yields nothing.
        """
        rangestart = next(
            self.st.query_endpoints(start=0, stop=max(start-1, 0),
                                    reverse=True),
            None
        )
        if rangestart is not None:
            rangestart = rangestart[0]
        LOG.debug('%s - range start: %s', self.name, rangestart)

        rangestop = next(
            self.st.query_endpoints(reverse=False,
                                    start=min(end+1, self.st.max_endpoint),
                                    stop=self.st.max_endpoint,
                                    include_start=False),
            None
        )
        if rangestop is not None:
            rangestop = rangestop[0]
        LOG.debug('%s - range stop: %s', self.name, rangestop)

        return rangestart, rangestop

    def _level_for_source(self, source):
        """Return WL_LEVEL for whitelist sources, 1 for every other one."""
        for p in self.whitelist_prefixes:
            if source.startswith(p):
                return WL_LEVEL
        return 1

    def _emit_range_changes(self, rangesb, rangesa, added, removed,
                            withdrawn_uuid=None, withdrawn_value=None):
        """Emit the difference between two snapshots of aggregated ranges.

        Order of emission: updates for new ranges, updates for ranges with
        the same extent but a different set of ids, withdraws for ranges
        that disappeared.

        Args:
            rangesb (set): MWUpdate set before the change.
            rangesa (set): MWUpdate set after the change.
            added (set): ``rangesa - rangesb``.
            removed (set): ``rangesb - rangesa``.
            withdrawn_uuid: id already deleted from the table, if any.
            withdrawn_value (dict): value that id had; used to build the
                value of withdrawn ranges.
        """
        for u in added:
            self.emit_update(
                u.indicator(),
                self._calc_indicator_value(u.uuids)
            )

        # MWUpdate hashes/compares on the extent only, so the "before"
        # twin of a surviving range can be found with a single lookup.
        before = {r: r for r in rangesb}
        for u in rangesa - added:
            ou = before.get(u)
            if ou is not None and u.uuids != ou.uuids:
                LOG.debug("IP range updated: %s", repr(u))
                self.emit_update(
                    u.indicator(),
                    self._calc_indicator_value(u.uuids)
                )

        for u in removed:
            self.emit_withdraw(
                u.indicator(),
                value=self._calc_indicator_value(
                    u.uuids,
                    additional_uuid=withdrawn_uuid,
                    additional_value=withdrawn_value
                )
            )

    @base._counting('update.processed')
    def filtered_update(self, source=None, indicator=None, value=None):
        """Handle an update of ``indicator`` coming from ``source``.

        Values whose 'type' is not 'IPv4' are counted and ignored. The
        value is stored, the segment is registered and the resulting
        changes in the aggregated ranges are emitted.
        """
        vtype = value.get('type', None)
        if vtype != 'IPv4':
            self.statistics['update.ignored'] += 1
            return

        # NOTE(review): the value is stored before the indicator is
        # validated, so an unparsable indicator still ends up in the table.
        v, newindicator = self._add_indicator(source, indicator, value)

        start, end = self._range_from_indicator(indicator)
        if start is None or end is None:
            return

        level = self._level_for_source(source)

        LOG.debug("%s - update: indicator: (%s) %s %s level: %s",
                  self.name, indicator, start, end, level)

        rangestart, rangestop = self._endpoints_from_range(start, end)

        rangesb = set(self._calc_ipranges(rangestart, rangestop))
        LOG.debug('%s - ranges before update: %s', self.name, rangesb)

        if not newindicator and level != WL_LEVEL:
            # Known segment: the ranges cannot change, only their merged
            # value, so the ranges in the window are just re-emitted.
            for u in rangesb:
                self.emit_update(
                    u.indicator(),
                    self._calc_indicator_value(u.uuids)
                )
            return

        uuidbytes = v['_id']
        self.st.put(uuidbytes, start, end, level=level)

        rangesa = set(self._calc_ipranges(rangestart, rangestop))
        LOG.debug('%s - ranges after update: %s', self.name, rangesa)

        added = rangesa-rangesb
        LOG.debug("%s - IP ranges added: %s", self.name, added)

        removed = rangesb-rangesa
        LOG.debug("%s - IP ranges removed: %s", self.name, removed)

        self._emit_range_changes(rangesb, rangesa, added, removed)

    @base._counting('withdraw.processed')
    def filtered_withdraw(self, source=None, indicator=None, value=None):
        """Handle a withdraw of ``indicator`` coming from ``source``.

        Unknown (indicator, source) pairs are ignored. Otherwise the entry
        and its segment are deleted and the resulting changes in the
        aggregated ranges are emitted.
        """
        LOG.debug("%s - withdraw from %s - %s", self.name, source, indicator)

        if value is not None and value.get('type', None) != 'IPv4':
            self.statistics['withdraw.ignored'] += 1
            return

        ik = self._indicator_key(indicator, source)

        v = self.table.get(ik)
        LOG.debug("%s - v: %s", self.name, v)
        if v is None:
            return

        self.table.delete(ik)
        self.statistics['removed'] += 1

        start, end = self._range_from_indicator(indicator)
        if start is None or end is None:
            return

        level = self._level_for_source(source)

        rangestart, rangestop = self._endpoints_from_range(start, end)

        rangesb = set(self._calc_ipranges(rangestart, rangestop))
        LOG.debug("ranges before: %s", rangesb)

        uuidbytes = v['_id']
        self.st.delete(uuidbytes, start, end, level=level)

        rangesa = set(self._calc_ipranges(rangestart, rangestop))
        LOG.debug("ranges after: %s", rangesa)

        added = rangesa-rangesb
        LOG.debug("IP ranges added: %s", added)

        removed = rangesb-rangesa
        LOG.debug("IP ranges removed: %s", removed)

        # The table entry is already gone: hand its value over so that the
        # withdrawn ranges can still carry its attributes.
        self._emit_range_changes(
            rangesb, rangesa, added, removed,
            withdrawn_uuid=v['_id'],
            withdrawn_value=v
        )

    def _send_indicators(self, source=None, from_key=None, to_key=None):
        """Send via RPC an update for every range between the two keys.

        Args:
            source: RPC destination.
            from_key (int): first address, defaults to 0.
            to_key (int): last address, defaults to 0xFFFFFFFF.
        """
        if from_key is None:
            from_key = 0
        if to_key is None:
            to_key = 0xFFFFFFFF

        result = self._calc_ipranges(from_key, to_key)
        for u in result:
            self.do_rpc(
                source,
                "update",
                indicator=u.indicator(),
                value=self._calc_indicator_value(u.uuids)
            )

    def get(self, source=None, indicator=None):
        """Return the merged value of a range found for a single address.

        Args:
            source: unused.
            indicator (str): IP address in text form.

        Returns:
            dict: merged value, or None when no range is found.

        Raises:
            ValueError: if ``indicator`` is not a string.
        """
        if not isinstance(indicator, _STRING_TYPES):
            raise ValueError("Invalid indicator type")

        indicator = int(netaddr.IPAddress(indicator))

        result = self._calc_ipranges(indicator, indicator)
        if len(result) == 0:
            return None

        u = result.pop()
        return self._calc_indicator_value(u.uuids)

    def get_all(self, source=None):
        """Send every aggregated range to ``source``; returns 'OK'."""
        self._send_indicators(source=source)
        return 'OK'

    def get_range(self, source=None, index=None, from_key=None, to_key=None):
        """Send the ranges between two addresses to ``source``.

        Args:
            source: RPC destination.
            index: must be None, no secondary index is supported.
            from_key (str): first address, optional.
            to_key (str): last address, optional.

        Returns:
            str: 'OK'

        Raises:
            ValueError: if ``index`` is not None.
        """
        if index is not None:
            raise ValueError('Index not found')

        if from_key is not None:
            from_key = int(netaddr.IPAddress(from_key))

        if to_key is not None:
            to_key = int(netaddr.IPAddress(to_key))

        self._send_indicators(
            source=source,
            from_key=from_key,
            to_key=to_key
        )

        return 'OK'

    def length(self, source=None):
        """Return ``num_indicators`` of the indicator table."""
        return self.table.num_indicators

    def stop(self):
        """Stop the node: kill pending requests and close the table."""
        super(AggregateIPv4FT, self).stop()

        for g in self.active_requests:
            g.kill()

        self.active_requests = []

        self.table.close()

        LOG.info("%s - # indicators: %d", self.name, self.table.num_indicators)

    @staticmethod
    def gc(name, config=None):
        """Remove the on-disk directories ``name`` and ``name_st``."""
        actorbase.ActorBaseFT.gc(name, config=config)

        shutil.rmtree(name, ignore_errors=True)
        shutil.rmtree('{}_st'.format(name), ignore_errors=True)