# -*- coding: utf-8 -*- # Copyright 2014-2015 Michael Helmling # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 3 as # published by the Free Software Foundation from __future__ import division, unicode_literals, print_function import json from collections import OrderedDict import sqlalchemy as sqla from sqlalchemy.sql import expression import lpdec from lpdec.persistence import JSONDecodable from lpdec import simulation, utils, database as db initialized = False simTable = joinTable = None def init(): """Initialize the simulations database module. This needs to be called before any other function of this module can be used, but after :func:`db.init`. """ global simTable, joinTable, initialized if initialized: return if not db.initialized: raise RuntimeError('database.init() needs to be called before simulation.init()') simTable = sqla.Table('simulations', db.metadata, sqla.Column('id', sqla.Integer, primary_key=True), sqla.Column('identifier', sqla.String(128)), sqla.Column('code', None, sqla.ForeignKey('codes.id')), sqla.Column('decoder', None, sqla.ForeignKey('decoders.id')), sqla.Column('channel_class', sqla.String(16)), sqla.Column('snr', sqla.Float), sqla.Column('channel_json', sqla.Text), sqla.Column('wordSeed', sqla.Integer), sqla.Column('samples', sqla.Integer), sqla.Column('errors', sqla.Integer), sqla.Column('cputime', sqla.Float), sqla.Column('stats', sqla.Text), sqla.Column('date_start', db.UTCDateTime), sqla.Column('date_end', db.UTCDateTime), sqla.Column('machine', sqla.Text), sqla.Column('program_name', sqla.String(64)), sqla.Column('program_version', sqla.String(64))) db.metadata.create_all(db.engine) joinTable = simTable.join(db.codesTable).join(db.decodersTable) initialized = True def teardown(): global simTable, joinTable, initialized simTable = joinTable = None initialized = False def existingIdentifiers(): """Returns a list of all identifiers for which simulation results exist in the database.""" s = sqla.select([simTable.c.identifier]).distinct() results = db.engine.execute(s).fetchall() return [r[0] for r in results] def addDataPoint(point): """Add (or update) a data point to the database. :param simulation.DataPoint point: DataPoint instance. .""" point.checkResume() codeId = db.checkCode(point.code, insert=True) decoderId = db.checkDecoder(point.decoder, insert=True) channelJSON = point.channel.toJSON() channelClass = type(point.channel).__name__ whereClause = ( (simTable.c.identifier == point.identifier) & (simTable.c.code == codeId) & (simTable.c.decoder == decoderId) & (simTable.c.channel_json == channelJSON) & (simTable.c.wordSeed == point.wordSeed) ) if utils.machineString() not in point.machine: point.machine = '{}/{}'.format(point.machine, utils.machineString()) values = dict(code=codeId, decoder=decoderId, identifier=point.identifier, channel_class=channelClass, snr=point.channel.snr, channel_json=channelJSON, wordSeed=point.wordSeed, samples=point.samples, errors=point.errors, cputime=point.cputime, date_start=point.date_start, date_end=point.date_end, machine=point.machine, program_name='lpdec', program_version=lpdec.exactVersion(), stats=json.dumps(point.stats, sort_keys=True)) result = db.engine.execute(sqla.select([simTable.c.id], whereClause)).fetchall() if len(result) > 0: assert len(result) == 1 simId = result[0][0] update = simTable.update().where(simTable.c.id == simId).values(**values) db.engine.execute(update) else: insert = simTable.insert().values(**values) db.engine.execute(insert) def dataPoint(code, channel, wordSeed, decoder, identifier): """Return a :class:`simulation.DataPoint` object for the given parameters. If such one exists in the database, it is initialized with the data (samples, errors etc.) from there. Otherwise an empty point is created. """ s = sqla.select([joinTable], (simTable.c.identifier == identifier) & (db.codesTable.c.name == code.name) & (db.decodersTable.c.name == decoder.name) & (simTable.c.channel_json == channel.toJSON()) & (simTable.c.wordSeed == wordSeed) ) ans = db.engine.execute(s).fetchone() point = simulation.DataPoint(code, channel, wordSeed, decoder, identifier) if ans is not None: point.samples = point._dbSamples = ans[simTable.c.samples] point.errors = ans[simTable.c.errors] point.cputime = point._dbCputime = ans[simTable.c.cputime] point.date_start = ans[simTable.c.date_start] point.date_end = ans[simTable.c.date_end] point.stats = json.loads(ans[simTable.c.stats]) point.version = ans[simTable.c.program_version] return point def dataPointFromRow(row): code = db.get('code', row[db.codesTable.c.name]) channel = JSONDecodable.fromJSON(row[simTable.c.channel_json]) wordSeed = row[simTable.c.wordSeed] decoder = db.get('decoder', row[db.decodersTable.c.name], code=code) identifier = row[simTable.c.identifier] point = simulation.DataPoint(code, channel, wordSeed, decoder, identifier) point.samples = point._dbSamples = row[simTable.c.samples] point.errors = row[simTable.c.errors] point.cputime = row[simTable.c.cputime] point.date_start = row[simTable.c.date_start] point.date_end = row[simTable.c.date_end] point.stats = json.loads(row[simTable.c.stats], object_pairs_hook=OrderedDict) point.version = row[simTable.c.program_version] point.program = row[simTable.c.program_name] point.machine = row[simTable.c.machine] return point def search(what, **conditions): if what == 'codename': columns = [db.codesTable.c.name] elif what == 'point': columns = [simTable.c.identifier, db.codesTable.c.name, db.decodersTable.c.name, simTable.c.channel_json, simTable.c.wordSeed, simTable.c.samples, simTable.c.errors, simTable.c.cputime, simTable.c.date_start, simTable.c.date_end, simTable.c.machine, simTable.c.program_name, simTable.c.program_version, simTable.c.stats] else: raise ValueError('unknown search: "{}"'.format(what)) condition = expression.true() for key, val in conditions.items(): if key == 'identifier': condition &= simTable.c.identifier.in_(val) elif key == 'code': condition &= db.codesTable.c.name.in_(val) else: raise ValueError() s = sqla.select(columns, whereclause=condition, from_obj=joinTable, distinct=True, use_labels=True).order_by(db.codesTable.c.name) ans = db.engine.execute(s).fetchall() if what == 'point': return [dataPointFromRow(row) for row in ans] return db.engine.execute(s).fetchall() def simulations(**conditions): points = search('point', **conditions) sims = {} for point in points: identifier = (point.code.name, point.decoder.name, point.channel.__class__, point.identifier, point.wordSeed, point.program) if identifier not in sims: sims[identifier] = simulation.Simulation() sims[identifier].add(point) return list(sims.values())