import csv import os import sys import time import collections import threading import logging from ztag.errors import IgnoreObject class Updater(object): """ Updater encapsulates the behavior for the updates.csv file; put_update() is called with each update, but output is only written every :frequency: seconds. """ def __init__(self, output=None, frequency=1.0, logger=None): self.output = output self.frequency = frequency self.logger = logger self.prev = None self._wrote_labels = False def put_update(self, row): if not self.output: return if self.prev and (row.time - self.prev.time) < self.frequency: return self.prev = row if not self._wrote_labels: self.output.write(row.get_csv_labels() + "\n") self._wrote_labels = True self.output.write(row.get_csv() + "\n") self.output.flush() def close(self): if self.output and self.output != sys.stderr: try: self.output.close() except BaseException as e: if self.logger: self.logger.warn("Failed to close updates CSV stream: %s", str(e)) class UpdateRow(object): """ UpdateRow encapsulates the information for a single update and the logic for outputting it as a CSV row. """ ORDER = ("skipped", "handled", "delta_skipped", "delta_handled") def __init__(self, skipped, handled, updated_at=None, prev=None): """ Construct a new row with the given number of skipped / handled entries, and calculate the deltas from prev (or set them to 0). Also sets time to now. :param skipped: current total number of skipped records :param handled: current total number of handled records :param prev: the previous UpdateRow """ self.time = updated_at or time.time() self.skipped = skipped self.handled = handled if prev: self.delta_skipped = skipped - prev.skipped self.delta_handled = handled - prev.handled else: self.delta_skipped = 0 self.delta_handled = 0 @classmethod def get_csv_labels(cls): return ",".join(cls.ORDER) def get_csv(self): return ",".join(str(getattr(self, label)) for label in self.ORDER) class Stream(object): def __init__(self, incoming, outgoing, transforms=None, logger=None, updates=None): super(Stream, self).__init__() self.incoming = incoming self.outgoing = outgoing self.transforms = transforms or list() self.logger = logger if updates: self.updater = Updater(output=updates, frequency=1.0, logger=logger) else: self.updater = None def put_update(self, skipped, handled): if not self.updater: return this_update = UpdateRow(skipped=skipped, handled=handled, prev=self.updater.prev) self.updater.put_update(this_update) def run(self): skipped = 0 handled = 0 for obj in self.incoming: self.put_update(handled=handled, skipped=skipped) try: out = obj for transformer in self.transforms: out = transformer.transform(out) if out is None: raise IgnoreObject() self.outgoing.take(out) handled += 1 except IgnoreObject as e: if self.logger: self.logger.debug(e.original_exception) self.logger.trace(obj) if e.trback: self.logger.warn(e.trback) skipped += 1 continue self.outgoing.cleanup() if self.updater: self.updater.close() return (handled, skipped) class Incoming(object): pass class InputFile(Incoming): def __init__(self, input_file=sys.stdin): self.input_file = input_file def __iter__(self): for line in self.input_file: yield line class InputCSV(Incoming): def __init__(self, input_file=sys.stdin): self.input_file = input_file self.csvdict = csv.DictReader(self.input_file) def __iter__(self): for record in self.csvdict: yield record class Outgoing(object): def __init__(self, *args, **kwargs): pass def take(self, obj): raise NotImplementedError def cleanup(self): pass class PythonPrint(Outgoing): def __init__(self, *args, **kwargs): super(PythonPrint, self).__init__() def take(self, obj): print obj class OutputFile(Outgoing): def __init__(self, output_file=sys.stdout, *args, **kwargs): super(OutputFile, self).__init__() self.output_file = output_file def take(self, obj): self.output_file.write(obj) self.output_file.write("\n") class RedisQueue(Outgoing): CERTIFICATES_QUEUE = "certificate" PUBKEY_QUEUE = "pubkey" # we might as well try to do a whole bunch. The _worst_ case scenario # by setting a limit too high is that the server runs out of memory # and kills python and the task fails. Which would have happened # anyway, because we couldn't connect to redis. MAX_RETRIES = 60 BATCH_SIZE = 250 def __init__(self, logger=None, destination=None, *args, **kwargs): import redis super(RedisQueue, self).__init__(*args, **kwargs) host = os.environ.get('ZTAG_REDIS_HOST', 'localhost') port = int(os.environ.get('ZTAG_REDIS_PORT', 6379)) if destination == "full_ipv4": queue = "ipv4" elif destination == "alexa_top1mil": queue = "domain" else: raise Exception("invalid destination: %s" % destination) self.logger = logger self.queue = queue try: self.redis = redis.Redis(host=host, port=port, db=0, socket_connect_timeout=10) except redis.ConnectionError as e: msg = "could not connect to redis: %s" % str(e) self.logger.fatal(msg) # batching self.queued = 0 self.retries = 0 self.records = [] self.certificates = [] def push(self, noretry=False): import redis if self.queued == 0: return try: p = self.redis.pipeline() for r in self.records: p.rpush(self.queue, r) for r in self.certificates: p.rpush(self.CERTIFICATES_QUEUE, r) p.execute() self.queued = 0 self.records = [] self.certificates = [] self.retries = 0 except redis.ConnectionError as e: time.sleep(1.0) self.retries += 1 if self.retries > self.MAX_RETRIES or noretry: msg = "redis connection error: %s" % str(e) self.logger.fatal(msg) self.redis = None def take(self, pbout): self.records.append(pbout.transformed) self.certificates.extend(pbout.certificates) self.queued += (len(pbout.certificates) + 1) if self.queued > self.BATCH_SIZE: self.push() def cleanup(self): return self.push(noretry=True) class Kafka(Outgoing): def __init__(self, logger=None, destination=None, *args, **kwargs): from kafka import KafkaProducer if destination == "full_ipv4": self.topic = "ipv4" elif destination == "alexa_top1mil": self.topic = "domain" else: raise Exception("invalid destination: %s" % destination) host = os.environ.get('KAFKA_BOOTSTRAP_HOST', 'localhost:9092') self.main_producer = KafkaProducer(bootstrap_servers=host) self.cert_producer = KafkaProducer(bootstrap_servers=host) def take(self, pbout): for certificate in pbout.certificates: self.cert_producer.send("certificate", certificate) self.main_producer.send(self.topic, pbout.transformed) def cleanup(self): if self.main_producer: self.main_producer.flush() if self.cert_producer: self.cert_producer.flush() failed_msg_t = collections.namedtuple('failed_msg_t', 'topic msg attempt') class PubsubState(): ''' Hold state with single course-grained lock. Restrict to safe operations on shared memory. ''' def __init__(self): self._lock = threading.Lock() self._npending_msgs = 0 self._failed_msgs = [] # An individual thread raising an exception or calling # sys.exit() will only end that thread. Use this to # signal the rest of the threads to exit. self.exit_exception = None def inc_npending(self): self._lock.acquire() self._npending_msgs += 1 self._lock.release() def dec_npending(self): self._lock.acquire() self._npending_msgs -= 1 self._lock.release() def get_npending(self): ''' No lock required to simply read int; no direct writes allowed. ''' return self._npending_msgs def add_failed_msg(self, topic, msg, attempt): self._lock.acquire() self._failed_msgs.append(failed_msg_t(topic, msg, attempt)) self._lock.release() def retrieve_failed_msgs(self): ''' Retrieve list of failed messages and reset running list. Returned value is no longer shared data. ''' self._lock.acquire() retval = self._failed_msgs self._failed_msgs = [] self._lock.release() return retval class Pubsub(Outgoing): MAX_ATTEMPTS = 5 def __init__(self, logger=None, destination=None, *args, **kwargs): import google from google.cloud import pubsub, pubsub_v1 self.logger = logger if logger is None: self.logger = logging.getLogger('null-logger') self.logger.setLevel(9999) if destination == "full_ipv4": self.topic_url = os.environ.get('PUBSUB_IPV4_TOPIC_URL') elif destination == "alexa_top1mil": self.topic_url = os.environ.get('PUBSUB_ALEXA_TOPIC_URL') self.cert_topic_url = os.environ.get('PUBSUB_CERT_TOPIC_URL') if not self.topic_url: raise Exception('missing $PUBSUB_[IPV4|ALEXA]_TOPIC_URL') if not self.cert_topic_url: raise Exception('missing $PUBSUB_CERT_TOPIC_URL') batch_settings = pubsub_v1.types.BatchSettings( # "The entire request including one or more messages must # be smaller than 10MB, after decoding." max_bytes=8192000, # 8 MB max_latency=15, # 15 seconds ) self.publisher = pubsub.PublisherClient(batch_settings) self.publish_count = {} try: self.publisher.get_topic(self.topic_url) self.publisher.get_topic(self.cert_topic_url) except google.api_core.exceptions.GoogleAPICallError as e: logger.error(e.message) raise self._state = PubsubState() def _make_done_callback(self, topic, data, attempt): def done_callback(future): if self._state.exit_exception: sys.exit(1) exception = future.exception() if not exception: self.logger.debug("Publish attempt #{attempt}/{max} on topic '{topic}' " "succeeded.".format(attempt=attempt + 1, max=self.MAX_ATTEMPTS, topic=topic)) self._state.dec_npending() else: self.logger.error("Publish attempt #{attempt}/{max} failed for data '{data}' on" "topic '{topic}' {error}" .format(attempt=attempt + 1, max=self.MAX_ATTEMPTS, data=data, topic=topic, error=str(exception))) if attempt >= self.MAX_ATTEMPTS: self._state.exit_exception = exception sys.exit(1) self._state.add_failed_msg(topic, data, attempt + 1) return done_callback def _publish_with_callback(self, topic, data, attempt): if attempt == 0: self._state.inc_npending() cb = self._make_done_callback(topic, data, attempt) publish_future = self.publisher.publish(topic, data) publish_future.add_done_callback(cb) def take(self, pbout): for certificate in pbout.certificates: self._publish_with_callback(self.cert_topic_url, certificate, 0) self._publish_with_callback(self.topic_url, pbout.transformed, 0) def cleanup(self): while self._state.get_npending() > 0: time.sleep(10) if self._state.exit_exception: self.logger.error("Max attempts exceeded; raising most recent exception.") raise self._state.exit_exception failed_msgs = self._state.retrieve_failed_msgs() self.logger.debug("Failed message queuelen: {}, " "messages pending: {}" .format(len(failed_msgs), self._state.get_npending())) for failed in failed_msgs: self._publish_with_callback(failed.topic, failed.msg, failed.attempt + 1) self.logger.debug("Pubsub cleanup: Finished.")