#!/usr/bin/python # -*- coding: utf-8 -*- import sqlite3 import logging import sys import os.path import json import random import validate # # Initialize the database once when we import the module. # PRJ_FILE = os.path.join('data', 'projects.sqlite') def ip_key(ip): return tuple(int(part) for part in ip.split('.')) class DatabaseException(Exception): pass class Database(): """ Class to handle all database interactions. """ def __init__(self, filename): """ Setup the connection and initialize the database. """ self.log = logging.getLogger('DATABASE') self.valid = validate.Validate() self.filename = filename self.con = sqlite3.connect(self.filename) self.con.row_factory = sqlite3.Row self.cur = self.con.cursor() def __del__(self): """ Clean up the database connection if it exists. """ if self.con is not None: self.con.close() def get_tables(self): """ Get a list of tables in the database. """ stmt = "SELECT name FROM sqlite_master WHERE type='table'" if self.execute_sql(stmt) is True: return [n['name'] for n in self.cur.fetchall()] else: return [] def execute_sql(self, stmt, args=None, commit=True): """ Execute an SQL statement. Attempt to execute an SQL statement and log any errors. Return True if successful and false if not. """ self.log.debug('Executing {0} with args {1}.'.format(stmt, args)) try: if args is None: self.cur.execute(stmt) else: self.cur.execute(stmt, args) if commit is True: self.con.commit() return True except sqlite3.Error as e: self.log.debug(e) return False class ScanDatabase(): """ Class to handle scan data and attack notes. """ def __init__(self, filename): self.log = logging.getLogger('DATABASE') self.itemdb = ItemDatabase(filename) self.attackdb = AttackDatabase(filename) self.hostdb = HostDatabase(filename) self.importdb = ImportDatabase(filename) def get_stats(self): """ Get host and attack stats for the database. """ self.log.debug('Gathering stats.') hosts = len(self.itemdb.get_unique_hosts()) attacks = len(self.attackdb.get_attacks()) return 'Hosts: {0} Attacks {1}'.format(hosts, attacks) def get_host_details(self, ip): """ Get all information associated with an IP. """ host = {'note': '', 'items': []} host['note'] = self.hostdb.get_host_note(ip) host['items'] = self.itemdb.get_items_by_ip(ip) return host def get_summary(self): """ Get summary information for all of the hosts. """ summary = [] hosts = self.itemdb.get_unique_hosts() for host in hosts: h = dict(self.hostdb.get_host(host)) ports = self.itemdb.get_ports_by_ip(host) h['tcp'] = [str(p) for p in sorted(ports['tcp'])] h['udp'] = [str(p) for p in sorted(ports['udp'])] summary.append(h) summary = sorted(summary, key=lambda x: ip_key(x['ip'])) return summary def get_unique(self): unique = {} ips = [ip for ip in self.itemdb.get_unique_hosts()] unique['ip'] = sorted(ips, key=lambda x: ip_key(x)) ports = self.itemdb.get_unique_ports() unique['tcp'] = [str(p) for p in sorted(ports['tcp'])] unique['udp'] = [str(p) for p in sorted(ports['udp'])] return unique class ItemDatabase(Database): """ Class to handle item data. """ def __init__(self, filename): Database.__init__(self, filename) items = ''' CREATE TABLE IF NOT EXISTS items ( id integer primary key autoincrement, ip text, port integer, protocol text, note text, hash text ) ''' ires = self.execute_sql(items) if ires is False: raise DatabaseException('Could not create items table.') def create_item(self, ip, port, protocol, note, hash): """ Add new item. """ self.log.debug('Creating new item.') try: self.valid.ip(ip) self.valid.port(port) self.valid.protocol(protocol) self.valid.hash(hash) except AssertionError as e: self.log.error(e) return False stmt = "INSERT INTO items (ip, port, protocol, note, hash) VALUES(?,?,?,?,?)" return self.execute_sql(stmt, (ip, port, protocol, note, hash)) def get_item(self, item_id): """ Get all items associated with an item_id. """ self.log.debug('Getting information for item {0}.'.format(item_id)) stmt = "SELECT * FROM items WHERE id=?" if self.execute_sql(stmt, (item_id,)) is True: return self.cur.fetchone() else: return {} def get_unique_hosts(self): """ Get unique hosts listed in the item database. """ self.log.debug('Getting unique hosts.') stmt = "SELECT DISTINCT ip FROM items ORDER BY ip" if self.execute_sql(stmt) is True: return [h['ip'] for h in self.cur.fetchall()] else: return [] def get_unique_ports(self): """ Get unique ports in the database. """ ports = {'tcp': [], 'udp': []} self.log.debug('Getting unique TCP ports from the database.') stmt = """SELECT DISTINCT(port) FROM items WHERE port != 0 AND protocol == 'tcp' ORDER BY port ASC""" if self.execute_sql(stmt) is True: ports['tcp'] = [h['port'] for h in self.cur.fetchall()] self.log.debug('Getting unique UDP ports from the database.') stmt = """SELECT DISTINCT(port) FROM items WHERE port != 0 AND protocol == 'udp' ORDER BY port ASC""" if self.execute_sql(stmt) is True: ports['udp'] = [h['port'] for h in self.cur.fetchall()] return ports def get_ports_by_ip(self, ip): """ Get unique TCP and UDP ports associated with an IP. """ ports = {} self.log.debug('Getting unique TCP ports for {0}.'.format(ip)) stmt = """SELECT DISTINCT(port) FROM items WHERE port != 0 AND protocol == 'tcp' AND ip == ? ORDER BY port ASC""" if self.execute_sql(stmt, (ip,)) is True: ports['tcp'] = [h['port'] for h in self.cur.fetchall()] self.log.debug('Getting unique UDP ports for {0}.'.format(ip)) stmt = """SELECT DISTINCT(port) FROM items WHERE port != 0 AND protocol == 'udp' AND ip == ? ORDER BY port ASC""" if self.execute_sql(stmt, (ip,)) is True: ports['udp'] = [h['port'] for h in self.cur.fetchall()] return ports def get_items_by_ip(self, ip): """ Get all items associated with a host. """ self.log.debug('Getting items for host {0}.'.format(ip)) stmt = "SELECT * FROM items WHERE ip=?" if self.execute_sql(stmt, (ip,)) is True: return self.cur.fetchall() else: return [] def get_items_by_hash(self, hash): """ Return a list of hosts with the specified hash. """ self.log.debug('Getting items associated with hash {0}.'.format(hash)) stmt = "SELECT ip FROM items WHERE hash=?" if self.execute_sql(stmt, (hash,)) is True: return [i['ip'] for i in self.cur.fetchall()] else: return [] def get_items_by_keywords(self, keywords): """ Return a list of items with the specified keywords. """ if keywords is None: return [] else: self.log.debug('Getting items associated with keywords {0}.'.format(','.join(keywords))) stmt = "SELECT id, ip, port FROM items WHERE " stmt += ' OR '.join(["note LIKE ?" for i in xrange(len(keywords))]) kw_strs = tuple(['%{0}%'.format(kw) for kw in keywords]) if self.execute_sql(stmt, kw_strs) is True: return [(i['id'], i['ip'], i['port']) for i in self.cur.fetchall()] else: return [] class AttackDatabase(Database): """ Class to handle attack data. """ def __init__(self, filename): Database.__init__(self, filename) attacks = ''' CREATE TABLE IF NOT EXISTS attacks ( id integer primary key autoincrement, name text, description text, items text, note text ) ''' ares = self.execute_sql(attacks) if ares is False: raise DatabaseException('Could not create attack table.') def create_attack(self, name, description, items): """ Create a new attack in the database. """ self.log.debug('Creating new attack for {0}.'.format(name)) stmt = "INSERT INTO attacks (name, description, items, note) VALUES(?,?,?,?)" return self.execute_sql(stmt, (name, description, ','.join(items), '')) def get_attack_by_name(self, name): """ Get an attack id by name. """ self.log.debug('Getting attack id for {0}.'.format(name)) stmt = "SELECT id, note FROM attacks WHERE name=?" if self.execute_sql(stmt, (name, ), commit=False) is True: return self.cur.fetchone() else: return None def get_attack(self, aid): """ Get an attack and a list of potential targets. """ self.log.debug('Getting attack {0}.'.format(aid)) stmt = "SELECT * FROM attacks WHERE id=?" if self.execute_sql(stmt, (aid,), commit=False) is True: return self.cur.fetchone() else: return None def get_attacks(self): """ Get all potential attacks. """ self.log.debug('Getting all potential attacks.') stmt = "SELECT id, name, description FROM attacks" if self.execute_sql(stmt, commit=False) is True: return self.cur.fetchall() else: return [] def get_attack_notes(self): """ Get all attack notes. """ self.log.debug('Getting notes for all attacks.') stmt = "SELECT name, note FROM attacks" if self.execute_sql(stmt, commit=False) is True: return [(a['name'], a['note']) for a in self.cur.fetchall()] else: return [] def update_attack_hosts(self, aid, items): """ Update the attack items. """ self.log.debug('Updating items for attack {0}.'.format(aid)) stmt = "UPDATE attacks SET items=? WHERE id=?" return self.execute_sql(stmt, (','.join(items), aid)) def update_attack_note(self, aid, note): """ Update the attack note. """ self.log.debug('Updating note for attack {0}.'.format(aid)) stmt = "UPDATE attacks SET note=? WHERE id=?" return self.execute_sql(stmt, (note, aid)) class HostDatabase(Database): """ Class to handle host data. """ def __init__(self, filename): Database.__init__(self, filename) hosts = ''' CREATE TABLE IF NOT EXISTS hosts ( id integer primary key autoincrement, ip text, os text, fqdn text, note text ) ''' hres = self.execute_sql(hosts) if hres is False: raise DatabaseException('Could not create hosts table.') def create_host(self, ip, os, fqdn): """ Create a new host identified by the IP address. """ self.log.debug('Creating new host for {0}.'.format(ip)) stmt = "INSERT INTO hosts (ip, os, fqdn) VALUES(?,?,?)" return self.execute_sql(stmt, (ip, os, fqdn)) def get_host(self, ip): """ Get host data associated with ip. """ self.log.debug('Getting host data for {0}.'.format(ip)) stmt = "SELECT ip, os, fqdn FROM hosts WHERE ip=?" if self.execute_sql(stmt, (ip,)) is True: return self.cur.fetchone() else: return {} def get_host_ip(self, ip): """ Return the host if it exists in the database. """ self.log.debug('Getting host record associated with IP {0}.'.format(ip)) stmt = "SELECT ip FROM hosts WHERE ip=? LIMIT=1" if self.execute_sql(stmt, (ip,)) is True: return [i['ip'] for i in self.cur.fetchall()] else: return [] def get_host_notes(self): """ Get all notes for hosts. """ self.log.debug('Getting all host notes.') stmt = "SELECT ip, note from hosts ORDER BY ip" if self.execute_sql(stmt) is True: return self.cur.fetchall() else: return [] def get_host_note(self, ip): """ Get notes for the specified host. """ self.log.debug('Getting notes for {0}.'.format(ip)) stmt = "SELECT note from hosts WHERE ip=?" if self.execute_sql(stmt, (ip,)) is True: return self.cur.fetchone()['note'] else: return "" def update_host_note(self, ip, note): """ Update the host note. """ self.log.debug('Updating note for host {0}.'.format(ip)) stmt = "UPDATE hosts SET note=? WHERE ip=?" return self.execute_sql(stmt, (note, ip)) class ImportDatabase(Database): """ Class to handle import data. """ def __init__(self, filename): Database.__init__(self, filename) imports = ''' CREATE TABLE IF NOT EXISTS imports ( id integer primary key autoincrement, filename text ) ''' ires = self.execute_sql(imports) if ires is False: raise DatabaseException('Could not create imports table.') def get_imported_files(self): """ Get all imported files for the specified project id. """ self.log.debug('Getting all imported files.') stmt = "SELECT filename FROM imports ORDER BY filename" if self.execute_sql(stmt) is True: return [p['filename'] for p in self.cur.fetchall()] else: return [] def add_import_file(self, filename): """ Add a filename to the table of imported files for a project. """ self.log.debug('Adding imported file {0}.'.format(filename)) stmt = "INSERT INTO imports (filename) VALUES (?)" return self.execute_sql(stmt, (filename,)) class ProjectDatabase(Database): """ Keep track of projects and the database names associated with them. """ def __init__(self): Database.__init__(self, PRJ_FILE) tables = self.get_tables() self.log.debug('TABLES: {0}'.format(tables)) if not ('projects' in tables): self.initialize_project_database() def initialize_project_database(self): """ Create a new project database. """ projects = ''' CREATE TABLE IF NOT EXISTS projects ( id integer primary key autoincrement, name text, note text, dbfile text ) ''' res = self.execute_sql(projects) if res is False: raise DatabaseException('Could not initialize project database.') def create_project(self, name): """ Add new project. """ self.log.debug('Creating new project.') db_name = ''.join([random.choice('0123456789abcdef') for _ in range(12)]) db_name = os.path.join('data', db_name + '.sqlite') try: scan_db = ScanDatabase(db_name) stmt = "INSERT INTO projects (name, dbfile) VALUES(?,?)" return self.execute_sql(stmt, (name, db_name)) except DatabaseException: return False def get_project(self, pid): """ Get the project name and database file associated with the pid. """ self.log.debug('Getting project for {0}.'.format(pid)) stmt = "SELECT name, dbfile, note FROM projects WHERE id=?" if self.execute_sql(stmt, (pid,)) is True: return self.cur.fetchone() else: return None, None def get_projects(self): """ Get all projects. """ self.log.debug('Getting all projects.') stmt = "SELECT * FROM projects ORDER BY name" if self.execute_sql(stmt) is True: return self.cur.fetchall() else: return [] def update_project_note(self, pid, note): """ Update the project notes. """ self.log.debug('Updating project {0}.'.format(pid)) stmt = "UPDATE projects SET note=? WHERE id=?" return self.execute_sql(stmt, (note, pid)) def delete_project(self, pid): """ Delete the project and the associated database file. """ self.log.debug('Deleting project {0}.'.format(pid)) name, db_file, _ = self.get_project(pid) if name is None: self.log.error('Could not find project {0}.'.format(pid)) stmt = "DELETE FROM projects WHERE id=?" if self.execute_sql(stmt, (pid,)) is True: self.delete_file(db_file) else: self.error('Could not delete project {0} from database.'.format(pid)) def delete_file(self, filename): """ Delete the specified file. """ try: os.remove(filename) except Exception as e: self.log.error('Could not delete file {0}: {1}'.format(filename, e))