""" SQL Database ============= SQLite database for metadata. .. autosummary:: :toctree: database Geometry Picks Client """ import os from operator import attrgetter from contextlib import contextmanager from sqlalchemy import create_engine, Column, Integer, BigInteger, ForeignKey, String, DateTime, Float, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from seisnn.utils import get_config Base = declarative_base() class Inventory(Base): """Inventory table for sql database.""" __tablename__ = 'inventory' network = Column(String, nullable=False) station = Column(String, primary_key=True) latitude = Column(Float, nullable=False) longitude = Column(Float, nullable=False) elevation = Column(Float, nullable=False) def __init__(self, net, sta, loc): self.network = net self.station = sta self.latitude = loc['latitude'] self.longitude = loc['longitude'] self.elevation = loc['elevation'] def __repr__(self): return f"Geometry(" \ f"Network={self.network}, " \ f"Station={self.station}, " \ f"Latitude={self.latitude:>7.4f}, " \ f"Longitude={self.longitude:>8.4f}, " \ f"Elevation={self.elevation:>6.1f})" def add_db(self, session): session.add(self) class Event(Base): """Event table for sql database.""" __tablename__ = 'event' id = Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) time = Column(DateTime, nullable=False) latitude = Column(Float, nullable=False) longitude = Column(Float, nullable=False) depth = Column(Float, nullable=False) def __init__(self, event): self.time = event.origins[0].time.datetime self.latitude = event.origins[0].latitude self.longitude = event.origins[0].longitude self.depth = event.origins[0].depth def __repr__(self): return f"Event(" \ f"Time={self.time}" \ f"Latitude={self.latitude:>7.4f}, " \ f"Longitude={self.longitude:>8.4f}, " \ f"Depth={self.depth:>6.1f})" def add_db(self, session): session.add(self) class Pick(Base): """Pick table for sql database.""" __tablename__ = 'pick' id = Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) time = Column(DateTime, nullable=False) station = Column(String, ForeignKey('inventory.station'), nullable=False) phase = Column(String, nullable=False) tag = Column(String, nullable=False) snr = Column(Float) def __init__(self, pick, tag): self.time = pick.time.datetime self.station = pick.waveform_id.station_code self.phase = pick.phase_hint self.tag = tag def __repr__(self): return f"Pick(" \ f"Time={self.time}, " \ f"Station={self.station}, " \ f"Phase={self.phase}, " \ f"Tag={self.tag}, " \ f"SNR={self.snr})" def add_db(self, session): session.add(self) class TFRecord(Base): """TFRecord table for sql database.""" __tablename__ = 'tfrecord' id = Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) file = Column(String) tag = Column(String) station = Column(String, ForeignKey('inventory.station')) def __init__(self, tfrecord): pass def __repr__(self): return f"TFRecord(" \ f"File={self.file}, " \ f"Tag={self.tag}, " \ f"Station={self.station})" def add_db(self, session): session.add(self) class Waveform(Base): """Waveform table for sql database.""" __tablename__ = 'waveform' id = Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) starttime = Column(DateTime, nullable=False) endtime = Column(DateTime, nullable=False) station = Column(String, ForeignKey('inventory.station')) tfrecord = Column(String, ForeignKey('tfrecord.file')) def __init__(self, waveform): pass def __repr__(self): return f"Waveform(" \ f"Start Time={self.starttime}, " \ f"End Time={self.endtime}, " \ f"Station={self.station}, " \ f"TFRecord={self.tfrecord})" def add_db(self, session): session.add(self) class Client: """A client for manipulate sql database""" def __init__(self, database, echo=False): config = get_config() self.database = database db_path = os.path.join(config['DATABASE_ROOT'], self.database) self.engine = create_engine(f'sqlite:///{db_path}?check_same_thread=False', echo=echo) Base.metadata.create_all(bind=self.engine) self.session = sessionmaker(bind=self.engine) def read_hyp(self, hyp, network): """seisnn.io.read_hyp wrap up""" from seisnn.io import read_hyp geom = read_hyp(hyp) self.add_geom(geom, network) def read_kml_placemark(self, kml, network): """seisnn.io.read_kml_placemark wrap up""" from seisnn.io import read_kml_placemark geom = read_kml_placemark(kml) self.add_geom(geom, network) def add_geom(self, geom, network): with self.session_scope() as session: counter = 0 for sta, loc in geom.items(): Inventory(network, sta, loc).add_db(session) counter += 1 session.commit() print(f'Input {counter} stations') def get_geom(self, station=None, network=None): with self.session_scope() as session: query = session.query(Inventory) if station: station = self.replace_sql_wildcard(station) query = query.filter(Inventory.station.like(station)) if network: network = self.replace_sql_wildcard(network) query = query.filter(Inventory.network.like(network)) return query def geom_summery(self): with self.session_scope() as session: station = session.query(Inventory.station).order_by(Inventory.station) station_count = session.query(Inventory.station).count() print(f'Station name:') print([stat[0] for stat in station], '\n') print(f'Total {station_count} stations\n') boundary = session.query(func.min(Inventory.longitude), func.max(Inventory.longitude), func.min(Inventory.latitude), func.max(Inventory.latitude)).all() print(f'Station boundary:') print(f'West: {boundary[0][0]:>8.4f}') print(f'East: {boundary[0][1]:>8.4f}') print(f'South: {boundary[0][2]:>7.4f}') print(f'North: {boundary[0][3]:>7.4f}\n') def plot_map(self): from seisnn.plot import plot_map with self.session_scope() as session: geometry = session.query(Inventory.latitude, Inventory.longitude, Inventory.network).all() events = session.query(Event.latitude, Event.longitude).all() plot_map(geometry, events) def add_events(self, catalog, tag, remove_duplicates=True): from seisnn.io import read_event_list events = read_event_list(catalog) with self.session_scope() as session: event_count = 0 pick_count = 0 for event in events: Event(event).add_db(session) event_count += 1 for pick in event.picks: Pick(pick, tag).add_db(session) pick_count += 1 print(f'Input {event_count} events, {pick_count} picks') if remove_duplicates: self.remove_duplicates(Event, ['time', 'latitude', 'longitude', 'depth']) self.remove_duplicates(Pick, ['time', 'phase', 'station', 'tag']) def get_picks(self, starttime=None, endtime=None, station=None, phase=None, tag=None): with self.session_scope() as session: query = session.query(Pick) if starttime: query = query.filter(Pick.time >= starttime) if endtime: query = query.filter(Pick.time <= endtime) if station: station = self.replace_sql_wildcard(station) query = query.filter(Pick.station.like(station)) if phase: query = query.filter(Pick.phase.like(phase)) if tag: query = query.filter(Pick.tag.like(tag)) return query def event_summery(self): with self.session_scope() as session: time = session.query(func.min(Event.time), func.max(Event.time)).all() print(f'Event time duration:') print(f'From: {time[0][0].isoformat()}') print(f'To: {time[0][1].isoformat()}\n') event_count = session.query(Event).count() print(f'Total {event_count} events\n') boundary = session.query(func.min(Event.longitude), func.max(Event.longitude), func.min(Event.latitude), func.max(Event.latitude)).all() print(f'Event boundary:') print(f'West: {boundary[0][0]:>8.4f}') print(f'East: {boundary[0][1]:>8.4f}') print(f'South: {boundary[0][2]:>7.4f}') print(f'North: {boundary[0][3]:>7.4f}\n') self.pick_summery() def pick_summery(self): with self.session_scope() as session: time = session.query(func.min(Pick.time), func.max(Pick.time)).all() print(f'Pick time duration:') print(f'From: {time[0][0].isoformat()}') print(f'To: {time[0][1].isoformat()}\n') print(f'Phase count:') phase_group_count = session.query(Pick.phase, func.count(Pick.phase)) \ .group_by(Pick.phase).all() ps_picks = 0 for phase, count in phase_group_count: if phase in ['P', 'S']: ps_picks += count print(f'{count} "{phase}" picks') print(f'Total {ps_picks} P + S picks\n') station_count = session.query(Pick.station.distinct()).count() print(f'Picks cover {station_count} stations:') station = session.query(Pick.station.distinct()).order_by(Pick.station).all() print([stat[0] for stat in station], '\n') no_pick_station = session.query(Inventory.station) \ .order_by(Inventory.station) \ .filter(Inventory.station.notin_(session.query(Pick.station))).all() if no_pick_station: print(f'{len(no_pick_station)} stations without picks:') print([stat[0] for stat in no_pick_station], '\n') no_geom_station = session.query(Pick.station.distinct()) \ .order_by(Pick.station) \ .filter(Pick.station.notin_(session.query(Inventory.station))).all() if no_geom_station: print(f'{len(no_geom_station)} stations without geometry:') print([stat[0] for stat in no_geom_station], '\n') def generate_training_data(self, output): from functools import partial from seisnn.utils import make_dirs, parallel from seisnn.io import _write_picked_stream, write_tfrecord config = get_config() dataset_dir = os.path.join(config['TFRECORD_ROOT'], output) make_dirs(dataset_dir) par = partial(_write_picked_stream, database=self.database) station_list = self.list_distinct_items(Pick, 'station') for station in station_list: file_name = '{}.tfrecord'.format(station) picks = self.get_picks(station=station).all() example_list = parallel(par, picks) save_file = os.path.join(dataset_dir, file_name) write_tfrecord(example_list, save_file) print(f'{file_name} done') def remove_duplicates(self, table, match_columns: list): with self.session_scope() as session: attrs = attrgetter(*match_columns) table_columns = attrs(table) distinct = session.query(table, func.min(table.id)) \ .group_by(*table_columns) duplicate = session.query(table) \ .filter(table.id.notin_(distinct.with_entities(table.id))) \ .delete(synchronize_session='fetch') print(f'Remove {duplicate} duplicate {table.__tablename__}s') def list_distinct_items(self, table, column): with self.session_scope() as session: col = attrgetter(column) query = session.query(col(table).distinct()).order_by(col(table)).all() query = [q[0] for q in query] return query @contextmanager def session_scope(self): """Provide a transactional scope around a series of operations.""" session = self.session() try: yield session session.commit() except Exception as exception: print(f'{exception.__class__.__name__}: {exception.__cause__}') session.rollback() finally: session.close() @staticmethod def replace_sql_wildcard(string): string = string.replace('?', '_') string = string.replace('*', '%') return string if __name__ == "__main__": pass