from __future__ import unicode_literals import logging import traceback from datetime import datetime from bson.dbref import DBRef from flask_login import current_user from flask_mongoengine.wtf import model_form from mongoengine import * from wtforms.fields import HiddenField as WTFHiddenField from wtforms.fields import StringField as WTFStringField from core.config.celeryctl import celery_app from core.database import Node, AttachedFile, TagListField from core.group import Group from core.scheduling import OneShotEntry from core.user import User class InvestigationLink(EmbeddedDocument): id = StringField(required=True) fromnode = StringField(required=True) tonode = StringField(required=True) label = StringField() @staticmethod def build(data): link = InvestigationLink( id=data['id'], fromnode=data['from'], tonode=data['to']) if 'label' in data: link.label = data['label'] return link class InvestigationEvent(EmbeddedDocument): kind = StringField(required=True) links = ListField(EmbeddedDocumentField(InvestigationLink)) nodes = ListField(ReferenceField('Node', dbref=True)) datetime = DateTimeField(default=datetime.utcnow) class Investigation(Node): name = StringField(verbose_name="Name") description = StringField(verbose_name="Description") links = ListField(EmbeddedDocumentField(InvestigationLink)) nodes = ListField(ReferenceField('Node', dbref=True)) events = ListField(EmbeddedDocumentField(InvestigationEvent)) created_by = StringField(verbose_name="Created By") created = DateTimeField(default=datetime.utcnow) updated = DateTimeField(default=datetime.utcnow) import_document = ReferenceField('AttachedFile') import_md = StringField() import_url = StringField() import_text = StringField() tags = ListField(StringField(), verbose_name="Relevant tags") sharing = ListField(StringField()) exclude_fields = [ 'links', 'nodes', 'events', 'created', 'updated', 'created_by', 'import_document', 'import_md', 'import_url', 'import_text', 'attached_files' ] # Ignore extra fields meta = { 'strict': False, "indexes": ["tags", "sharing"], } @classmethod def get_form(klass): """Gets the appropriate form for a given investigation""" form = model_form(klass, exclude=klass.exclude_fields) # An empty name is the same as no name form.name = WTFStringField( 'Name', filters=[lambda name: name or None]) form.created_by = WTFHiddenField( 'created_by', default=current_user.username) form.tags = TagListField("Tags") return form SEARCH_ALIASES = {} def info(self): result = self.to_mongo() result['nodes'] = [node.to_mongo() for node in self.nodes] shared = [] for sharing_id in Investigation.objects.get(id=self.id).sharing: try: shared.append(Group.objects.get(id=sharing_id)) except: try: shared.append(User.objects.get(id=sharing_id)) except: pass if shared: result['shared'] = shared return result def _node_changes(self, kind, method, links, nodes): event = InvestigationEvent(kind=kind) for link in links: link = InvestigationLink.build(link) if method('links', link.to_mongo()): event.links.append(link) for node in nodes: if not isinstance(node, DBRef): node = node.to_dbref() if method('nodes', node): event.nodes.append(node) if len(event.nodes) > 0 or len(event.links) > 0: self.modify(push__events=event, updated=datetime.utcnow()) def add(self, links, nodes): self._node_changes('add', self.add_to_set, links, nodes) def remove(self, links, nodes): self._node_changes('remove', self.remove_from_set, links, nodes) def save(self, *args, **kwargs): self.updated = datetime.utcnow() return super(Investigation, self).save(*args, **kwargs) def sharing_permissions(self, sharing_with, investigation=False, invest_id=False): groups = False if sharing_with == "all": Investigation.objects.get(id=invest_id or self.id).update(set__sharing=[]) elif sharing_with == "private": Investigation.objects.get(id=invest_id or self.id).update(add_to_set__sharing=[current_user.id]) elif sharing_with == "allg": groups = Group.objects(members__in=[current_user.id]) else: groups = Group.objects(id=sharing_with) if groups: Investigation.objects.get(id=self.id).update(add_to_set__sharing=[group.id for group in groups]) class ImportResults(Document): import_method = ReferenceField('ImportMethod', required=True) status = StringField(required=True) investigation = ReferenceField('Investigation') error = StringField() class ImportMethod(OneShotEntry): acts_on = StringField() def run(self, target): results = ImportResults(import_method=self, status='pending') results.investigation = Investigation(created_by=current_user.username) if isinstance(target, AttachedFile): results.investigation.import_document = target target = target.filepath else: results.investigation.import_url = target results.investigation.save() results.save() celery_app.send_task( "core.investigation.import_task", [str(results.id), target]) return results @celery_app.task def import_task(results_id, target): results = ImportResults.objects.get(id=results_id) import_method = results.import_method logging.warning( "Running one-shot import {} on {}".format( import_method.__class__.__name__, target)) results.update(status="running") try: import_method.do_import(results, target) results.update(status="finished") except Exception as e: results.update(status="error", error=str(e)) traceback.print_exc()