"""Module with classes to read and store data about work entities. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from io import StringIO import pickle import random import time import numpy as np from six import iteritems from six import itervalues from six import PY3 from six import text_type # Cloud Datastore constants KIND_WORK_TYPE = u'WorkType' KIND_WORK = u'Work' ID_ATTACKS_WORK_ENTITY = u'AllAttacks' ID_DEFENSES_WORK_ENTITY = u'AllDefenses' ATTACK_WORK_ID_PATTERN = u'WORKA{:03}' DEFENSE_WORK_ID_PATTERN = u'WORKD{:05}' # Constants for __str__ TO_STR_MAX_WORK = 20 # How long worker is allowed to process one piece of work, # before considered failed MAX_PROCESSING_TIME = 600 # Number of work records to read at once MAX_WORK_RECORDS_READ = 1000 def get_integer_time(): """Returns current time in long integer format.""" if PY3: return int(time.time()) else: return long(time.time()) def is_unclaimed(work): """Returns True if work piece is unclaimed.""" if work['is_completed']: return False cutoff_time = time.time() - MAX_PROCESSING_TIME if (work['claimed_worker_id'] and work['claimed_worker_start_time'] is not None and work['claimed_worker_start_time'] >= cutoff_time): return False return True class WorkPiecesBase(object): """Base class to store one piece of work. In adversarial competition, all work consists of the following: - evaluation of all attacks on all images from dataset which results in generation of adversarial images; - evaluation of all defenses on all adversarial images which results in storing classification labels. One piece of work is either evaluation of one attack on a subset of images or evaluation of one defense on a subset of adversarial images. This way all work is split into work pieces which could be computed independently in parallel by different workers. Each work piece is identified by unique ID and has one of the following statuses: - Unclaimed. This means that no worker has started working on the work piece. - Claimed by worker NN. This means that worker NN is working on this work piece. After workpiece being claimed for too long (more than MAX_PROCESSING_TIME seconds) it automatically considered unclaimed. This is needed in case worker failed while processing the work piece. - Completed. This means that computation of work piece is done. Additionally each work piece may be assigned to a shard. In such case workers are also grouped into shards. Each time worker looking for a work piece it first tries to find undone work from the shard worker is assigned to. Only after all work from this shard is done, worker will try to claim work pieces from other shards. The purpose of sharding is to reduce load on Google Cloud Datastore. """ def __init__(self, datastore_client, work_type_entity_id): """Initializes WorkPiecesBase class. Args: datastore_client: instance of CompetitionDatastoreClient. work_type_entity_id: ID of the WorkType parent entity """ self._datastore_client = datastore_client self._work_type_entity_id = work_type_entity_id # Dictionary: work_id -> dict with properties of the piece of work # # Common properties are following: # - claimed_worker_id - worker id which claimed the work # - claimed_worker_start_time - when work was claimed # - is_completed - whether work is completed or not # - error - if not None then work was completed with error # - elapsed_time - time took to complete the work # - shard_id - ID of the shard which run the work # - submission_id - ID of the submission which should be executed # # Additionally piece of work will have property specific to work type: # output_adversarial_batch_id for attack and output_classification_batch_id # for defense. Also upon completion of the work, worker may write # additional statistics field to the work. self._work = {} def serialize(self, fobj): """Serialize work pieces into file object.""" pickle.dump(self._work, fobj) def deserialize(self, fobj): """Deserialize work pieces from file object.""" self._work = pickle.load(fobj) @property def work(self): """Dictionary with all work pieces.""" return self._work def replace_work(self, value): """Replaces work with provided value. Generally this method should be called only by master, that's why it separated from the property self.work. Args: value: dictionary with new work pieces """ assert isinstance(value, dict) self._work = value def __len__(self): return len(self._work) def is_all_work_competed(self): """Returns whether all work pieces are completed or not.""" return all([w['is_completed'] for w in itervalues(self.work)]) def write_all_to_datastore(self): """Writes all work pieces into datastore. Each work piece is identified by ID. This method writes/updates only those work pieces which IDs are stored in this class. For examples, if this class has only work pieces with IDs '1' ... '100' and datastore already contains work pieces with IDs '50' ... '200' then this method will create new work pieces with IDs '1' ... '49', update work pieces with IDs '50' ... '100' and keep unchanged work pieces with IDs '101' ... '200'. """ client = self._datastore_client with client.no_transact_batch() as batch: parent_key = client.key(KIND_WORK_TYPE, self._work_type_entity_id) batch.put(client.entity(parent_key)) for work_id, work_val in iteritems(self._work): entity = client.entity(client.key(KIND_WORK, work_id, parent=parent_key)) entity.update(work_val) batch.put(entity) def read_all_from_datastore(self): """Reads all work pieces from the datastore.""" self._work = {} client = self._datastore_client parent_key = client.key(KIND_WORK_TYPE, self._work_type_entity_id) for entity in client.query_fetch(kind=KIND_WORK, ancestor=parent_key): work_id = entity.key.flat_path[-1] self.work[work_id] = dict(entity) def _read_undone_shard_from_datastore(self, shard_id=None): """Reads undone worke pieces which are assigned to shard with given id.""" self._work = {} client = self._datastore_client parent_key = client.key(KIND_WORK_TYPE, self._work_type_entity_id) filters = [('is_completed', '=', False)] if shard_id is not None: filters.append(('shard_id', '=', shard_id)) for entity in client.query_fetch(kind=KIND_WORK, ancestor=parent_key, filters=filters): work_id = entity.key.flat_path[-1] self.work[work_id] = dict(entity) if len(self._work) >= MAX_WORK_RECORDS_READ: break def read_undone_from_datastore(self, shard_id=None, num_shards=None): """Reads undone work from the datastore. If shard_id and num_shards are specified then this method will attempt to read undone work for shard with id shard_id. If no undone work was found then it will try to read shard (shard_id+1) and so on until either found shard with undone work or all shards are read. Args: shard_id: Id of the start shard num_shards: total number of shards Returns: id of the shard with undone work which was read. None means that work from all datastore was read. """ if shard_id is not None: shards_list = [(i + shard_id) % num_shards for i in range(num_shards)] else: shards_list = [] shards_list.append(None) for shard in shards_list: self._read_undone_shard_from_datastore(shard) if self._work: return shard return None def try_pick_piece_of_work(self, worker_id, submission_id=None): """Tries pick next unclaimed piece of work to do. Attempt to claim work piece is done using Cloud Datastore transaction, so only one worker can claim any work piece at a time. Args: worker_id: ID of current worker submission_id: if not None then this method will try to pick piece of work for this submission Returns: ID of the claimed work piece """ client = self._datastore_client unclaimed_work_ids = None if submission_id: unclaimed_work_ids = [ k for k, v in iteritems(self.work) if is_unclaimed(v) and (v['submission_id'] == submission_id) ] if not unclaimed_work_ids: unclaimed_work_ids = [k for k, v in iteritems(self.work) if is_unclaimed(v)] if unclaimed_work_ids: next_work_id = random.choice(unclaimed_work_ids) else: return None try: with client.transaction() as transaction: work_key = client.key(KIND_WORK_TYPE, self._work_type_entity_id, KIND_WORK, next_work_id) work_entity = client.get(work_key, transaction=transaction) if not is_unclaimed(work_entity): return None work_entity['claimed_worker_id'] = worker_id work_entity['claimed_worker_start_time'] = get_integer_time() transaction.put(work_entity) except: return None return next_work_id def update_work_as_completed(self, worker_id, work_id, other_values=None, error=None): """Updates work piece in datastore as completed. Args: worker_id: ID of the worker which did the work work_id: ID of the work which was done other_values: dictionary with additonal values which should be saved with the work piece error: if not None then error occurred during computation of the work piece. In such case work will be marked as completed with error. Returns: whether work was successfully updated """ client = self._datastore_client try: with client.transaction() as transaction: work_key = client.key(KIND_WORK_TYPE, self._work_type_entity_id, KIND_WORK, work_id) work_entity = client.get(work_key, transaction=transaction) if work_entity['claimed_worker_id'] != worker_id: return False work_entity['is_completed'] = True if other_values: work_entity.update(other_values) if error: work_entity['error'] = text_type(error) transaction.put(work_entity) except: return False return True def compute_work_statistics(self): """Computes statistics from all work pieces stored in this class.""" result = {} for v in itervalues(self.work): submission_id = v['submission_id'] if submission_id not in result: result[submission_id] = { 'completed': 0, 'num_errors': 0, 'error_messages': set(), 'eval_times': [], 'min_eval_time': None, 'max_eval_time': None, 'mean_eval_time': None, 'median_eval_time': None, } if not v['is_completed']: continue result[submission_id]['completed'] += 1 if 'error' in v and v['error']: result[submission_id]['num_errors'] += 1 result[submission_id]['error_messages'].add(v['error']) else: result[submission_id]['eval_times'].append(float(v['elapsed_time'])) for v in itervalues(result): if v['eval_times']: v['min_eval_time'] = np.min(v['eval_times']) v['max_eval_time'] = np.max(v['eval_times']) v['mean_eval_time'] = np.mean(v['eval_times']) v['median_eval_time'] = np.median(v['eval_times']) return result def __str__(self): buf = StringIO() buf.write(u'WorkType "{0}"\n'.format(self._work_type_entity_id)) for idx, (work_id, work_val) in enumerate(iteritems(self.work)): if idx >= TO_STR_MAX_WORK: buf.write(u' ...\n') break buf.write(u' Work "{0}"\n'.format(work_id)) buf.write(u' {0}\n'.format(str(work_val))) return buf.getvalue() class AttackWorkPieces(WorkPiecesBase): """Subclass which represents work pieces for adversarial attacks.""" def __init__(self, datastore_client): """Initializes AttackWorkPieces.""" super(AttackWorkPieces, self).__init__( datastore_client=datastore_client, work_type_entity_id=ID_ATTACKS_WORK_ENTITY) def init_from_adversarial_batches(self, adv_batches): """Initializes work pieces from adversarial batches. Args: adv_batches: dict with adversarial batches, could be obtained as AversarialBatches.data """ for idx, (adv_batch_id, adv_batch_val) in enumerate(iteritems(adv_batches)): work_id = ATTACK_WORK_ID_PATTERN.format(idx) self.work[work_id] = { 'claimed_worker_id': None, 'claimed_worker_start_time': None, 'is_completed': False, 'error': None, 'elapsed_time': None, 'submission_id': adv_batch_val['submission_id'], 'shard_id': None, 'output_adversarial_batch_id': adv_batch_id, } class DefenseWorkPieces(WorkPiecesBase): """Subclass which represents work pieces for adversarial defenses.""" def __init__(self, datastore_client): """Initializes DefenseWorkPieces.""" super(DefenseWorkPieces, self).__init__( datastore_client=datastore_client, work_type_entity_id=ID_DEFENSES_WORK_ENTITY) def init_from_class_batches(self, class_batches, num_shards=None): """Initializes work pieces from classification batches. Args: class_batches: dict with classification batches, could be obtained as ClassificationBatches.data num_shards: number of shards to split data into, if None then no sharding is done. """ shards_for_submissions = {} shard_idx = 0 for idx, (batch_id, batch_val) in enumerate(iteritems(class_batches)): work_id = DEFENSE_WORK_ID_PATTERN.format(idx) submission_id = batch_val['submission_id'] shard_id = None if num_shards: shard_id = shards_for_submissions.get(submission_id) if shard_id is None: shard_id = shard_idx % num_shards shards_for_submissions[submission_id] = shard_id shard_idx += 1 # Note: defense also might have following fields populated by worker: # stat_correct, stat_error, stat_target_class, stat_num_images self.work[work_id] = { 'claimed_worker_id': None, 'claimed_worker_start_time': None, 'is_completed': False, 'error': None, 'elapsed_time': None, 'submission_id': submission_id, 'shard_id': shard_id, 'output_classification_batch_id': batch_id, }