from functools import partial
from collections import defaultdict

import sqlalchemy

from ..types import String, Integer
from ..graph import Nothing, Maybe, One, Many
from ..engine import pass_context


def _translate_type(column):
    """Map a SQLAlchemy column's type to the corresponding graph type.

    Returns ``Integer`` for integer columns, ``String`` for unicode columns
    and ``None`` for any other column type.
    """
    if isinstance(column.type, sqlalchemy.Integer):
        return Integer
    elif isinstance(column.type, sqlalchemy.Unicode):
        return String
    else:
        return None


def _table_repr(table):
    """Build a short ``Table(...)`` repr that omits the column definitions."""
    return 'Table({})'.format(', '.join(
        [repr(table.name), repr(table.metadata), '...',
         'schema={!r}'.format(table.schema)]
    ))


@pass_context
class FieldsQuery:
    """Load field values for a set of ids from a single selectable.

    :param engine_key: key used to look up the SQLAlchemy engine in the
        context passed to ``__call__``
    :param from_clause: table (or other selectable) to read columns from
    :param primary_key: column matched against the requested ids; defaults
        to the single column of ``from_clause.primary_key``
    """

    def __init__(self, engine_key, from_clause, *, primary_key=None):
        self.engine_key = engine_key
        self.from_clause = from_clause
        if primary_key is not None:
            self.primary_key = primary_key
        else:
            # currently only one column supported; the tuple-unpacking
            # below fails for composite primary keys
            self.primary_key, = from_clause.primary_key

    def __repr__(self):
        if isinstance(self.from_clause, sqlalchemy.Table):
            # the default Table repr lists every column, which is too noisy
            from_clause_repr = _table_repr(self.from_clause)
        else:
            from_clause_repr = repr(self.from_clause)
        return ('<{}.{}: engine_key={!r}, from_clause={}, primary_key={!r}>'
                .format(self.__class__.__module__, self.__class__.__name__,
                        self.engine_key, from_clause_repr, self.primary_key))

    def __postprocess__(self, field):
        """Infer ``field.type`` from the same-named column when it is unset."""
        if field.type is None:
            column = self.from_clause.c[field.name]
            field.type = _translate_type(column)

    def in_impl(self, column, values):
        """Build the ``column IN values`` filter; subclasses may override."""
        return column.in_(values)

    def select_expr(self, fields_, ids):
        """Build the query plus a function turning its rows into results.

        Returns an ``(expr, result_proc)`` pair. ``result_proc(rows)``
        returns one list of field values per id in ``ids`` (same order);
        ids without a matching row get a list of ``None`` values.
        """
        columns = [self.from_clause.c[f.name] for f in fields_]
        expr = (
            sqlalchemy.select([self.primary_key] + columns)
            .select_from(self.from_clause)
            .where(self.in_impl(self.primary_key, ids))
        )

        def result_proc(rows):
            rows_map = {row[self.primary_key]: [row[c] for c in columns]
                        for row in rows}
            # Build a fresh list of nulls for every missing id: returning one
            # shared list would let a mutation of one result row silently
            # change the rows of all other missing ids.
            return [rows_map[id_] if id_ in rows_map
                    else [None] * len(columns)
                    for id_ in ids]

        return expr, result_proc

    def __call__(self, ctx, fields_, ids):
        """Fetch ``fields_`` for ``ids`` using the engine stored in ``ctx``."""
        if not ids:
            # nothing to load; skip the database round-trip entirely
            return []
        expr, result_proc = self.select_expr(fields_, ids)
        sa_engine = ctx[self.engine_key]
        with sa_engine.connect() as connection:
            rows = connection.execute(expr).fetchall()
        return result_proc(rows)


def _to_maybe_mapper(pairs, values):
    """Map each value to its linked value, or ``Nothing`` when it has none.

    If ``pairs`` holds several targets for one value, the last one wins.
    """
    mapping = dict(pairs)
    return [mapping.get(value, Nothing) for value in values]


def _to_one_mapper(pairs, values):
    """Map each value to its linked value; raises ``KeyError`` if missing.

    If ``pairs`` holds several targets for one value, the last one wins.
    """
    mapping = dict(pairs)
    return [mapping[value] for value in values]


def _to_many_mapper(pairs, values):
    """Map each value to the list of all its linked values (possibly empty)."""
    mapping = defaultdict(list)
    for from_value, to_value in pairs:
        mapping[from_value].append(to_value)
    return [mapping[value] for value in values]


class LinkQuery:
    """Resolve links between ids using two columns of the same table.

    :param engine_key: key used to look up the SQLAlchemy engine in the
        context passed to ``__call__``
    :param from_column: column matched against the incoming ids
    :param to_column: column whose values are returned as link targets
    :raises ValueError: if the two columns belong to different tables
    """

    def __init__(self, engine_key, *, from_column, to_column):
        if from_column.table is not to_column.table:
            raise ValueError('from_column and to_column should belong to '
                             'one table')
        self.engine_key = engine_key
        self.from_column = from_column
        self.to_column = to_column

    def __repr__(self):
        return ('<{}.{}: engine_key={!r}, from_column={!r}, to_column={!r}>'
                .format(self.__class__.__module__, self.__class__.__name__,
                        self.engine_key, self.from_column, self.to_column))

    def __postprocess__(self, link):
        """Install ``link.func`` using the mapper matching ``link.type_enum``.

        :raises TypeError: if ``link.type_enum`` is not One, Maybe or Many
        """
        if link.type_enum is One:
            func = partial(self, _to_one_mapper)
        elif link.type_enum is Maybe:
            func = partial(self, _to_maybe_mapper)
        elif link.type_enum is Many:
            func = partial(self, _to_many_mapper)
        else:
            raise TypeError(repr(link.type_enum))
        # NOTE(review): presumably pass_context marks func so the engine
        # supplies ctx, which lands right after the bound result_proc in
        # __call__'s signature — confirm against ..engine
        link.func = pass_context(func)

    def in_impl(self, column, values):
        """Build the ``column IN values`` filter; subclasses may override."""
        return column.in_(values)

    def select_expr(self, ids):
        """Build a query returning ``(from_column, to_column)`` pairs.

        ``None`` ids are dropped and duplicates removed. Returns ``None``
        when no ids remain, so the caller can skip the query.
        """
        # TODO: make this optional, but enabled by default
        filtered_ids = [i for i in set(ids) if i is not None]
        if filtered_ids:
            return (
                sqlalchemy.select([self.from_column.label('from_column'),
                                   self.to_column.label('to_column')])
                .where(self.in_impl(self.from_column, filtered_ids))
            )
        else:
            return None

    def __call__(self, result_proc, ctx, ids):
        """Fetch link pairs for ``ids`` and map them with ``result_proc``."""
        expr = self.select_expr(ids)
        if expr is None:
            pairs = []
        else:
            sa_engine = ctx[self.engine_key]
            with sa_engine.connect() as connection:
                pairs = connection.execute(expr).fetchall()
        return result_proc(pairs, ids)