from siuba.dply.verbs import ( singledispatch2, show_query, collect, simple_varname, select, VarList, var_select, mutate, transmute, filter, arrange, _call_strip_ascending, summarize, count, group_by, ungroup, case_when, join, left_join, right_join, inner_join, semi_join, anti_join, head, rename, distinct, if_else ) from .translate import sa_get_over_clauses, CustomOverClause, SqlColumn, SqlColumnAgg from .utils import get_dialect_funcs, get_sql_classes from sqlalchemy import sql import sqlalchemy from siuba.siu import Call, CallTreeLocal, str_to_getitem_call, Lazy, FunctionLookupError # TODO: currently needed for select, but can we remove pandas? from pandas import Series import pandas as pd from sqlalchemy.sql import schema # TODO: # - distinct # - annotate functions using sel.prefix_with("\n/*<Mutate Statement>*/\n") ? # Helpers --------------------------------------------------------------------- class SqlFunctionLookupError(FunctionLookupError): pass class CallListener: """Generic listener. Each exit is called on a node's copy.""" def enter(self, node): args, kwargs = node.map_subcalls(self.enter) return self.exit(node.__class__(node.func, *args, **kwargs)) def exit(self, node): return node class WindowReplacer(CallListener): """Call tree listener. Produces 2 important behaviors via the enter method: - returns evaluated sql call expression, with labels on all window expressions. - stores all labeled window expressions via the windows property. TODO: could replace with a sqlalchemy transformer """ def __init__(self, columns, group_by, order_by, window_cte = None): self.columns = columns self.group_by = group_by self.order_by = order_by self.window_cte = window_cte self.windows = [] def exit(self, node): col_expr = node(self.columns) if not isinstance(col_expr, sql.elements.ClauseElement): return col_expr over_clauses = [x for x in sa_get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] # put groupings and orderings onto custom over clauses for over in over_clauses: # TODO: shouldn't mutate these over clauses group_by = sql.elements.ClauseList( *[self.columns[name] for name in self.group_by] ) order_by = sql.elements.ClauseList( *_create_order_by_clause(self.columns, *self.order_by) ) over.set_over(group_by, order_by) if len(over_clauses) and self.window_cte is not None: label = col_expr.label(None) self.windows.append(label) # optionally put into CTE, and return its resulting column self.window_cte.append_column(label) win_col = self.window_cte.c.values()[-1] return win_col return col_expr def track_call_windows(call, columns, group_by, order_by, window_cte = None): listener = WindowReplacer(columns, group_by, order_by, window_cte) col = listener.enter(call) return col, listener.windows def lift_inner_cols(tbl): cols = list(tbl.inner_columns) data = {col.key: col for col in cols} return sql.base.ImmutableColumnCollection(data, cols) def col_expr_requires_cte(call, sel, is_mutate = False): """Return whether a variable assignment needs a CTE""" call_vars = set(call.op_vars(attr_calls = False)) columns = lift_inner_cols(sel) sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) # I use the acronym fwg sol (frog soul) to remember sql clause eval order # from, where, group by, select, order by, limit # group clause evaluated before select clause, so not issue for mutate group_needs_cte = not is_mutate and len(sel._group_by_clause) return ( group_needs_cte or len(sel._order_by_clause) or not sel_labs.isdisjoint(call_vars) ) def get_missing_columns(call, columns): missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) return missing_cols def compile_el(tbl, el): compiled = el.compile( dialect = tbl.source.dialect, compile_kwargs = {"literal_binds": True} ) return compiled # Misc utilities -------------------------------------------------------------- def ordered_union(x, y): dx = {el: True for el in x} dy = {el: True for el in y} return tuple({**dx, **dy}) # Table ----------------------------------------------------------------------- class LazyTbl: def __init__( self, source, tbl, columns = None, ops = None, group_by = tuple(), order_by = tuple(), funcs = None, rm_attr = ('str', 'dt'), call_sub_attr = ('dt', 'str'), dispatch_cls = None, result_cls = sql.elements.ClauseElement ): """Create a representation of a SQL table. Args: source: a sqlalchemy.Engine or sqlalchemy.Connection instance. tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. columns: if specified, a listlike of column names. Examples -------- :: from sqlalchemy import create_engine from siuba.data import mtcars # create database and table engine = create_engine("sqlite:///:memory:") mtcars.to_sql('mtcars', engine) tbl_mtcars = LazyTbl(engine, 'mtcars') """ # connection and dialect specific functions self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source dialect = self.source.dialect.name self.funcs = get_dialect_funcs(dialect) if funcs is None else funcs self.dispatch_cls = get_sql_classes(dialect) if dispatch_cls is None else dispatch_cls self.result_cls = result_cls self.tbl = self._create_table(tbl, columns, self.source) # important states the query can be in (e.g. grouped) self.ops = [sql.Select([self.tbl])] if ops is None else ops self.group_by = group_by self.order_by = order_by # customizations to allow interop with pandas (e.g. handle dt methods) self.rm_attr = rm_attr self.call_sub_attr = call_sub_attr self.result_cls = result_cls def append_op(self, op, **kwargs): cpy = self.copy(**kwargs) cpy.ops = cpy.ops + [op] return cpy def copy(self, **kwargs): return self.__class__(**{**self.__dict__, **kwargs}) def shape_call( self, call, window = True, str_accessors = False, verb_name = None, arg_name = None, ): # TODO: error if mutate receives a literal value? # TODO: dispatch_cls currently unused if str_accessors and isinstance(call, str): # verbs that can use strings as accessors, like group_by, or # arrange, need to convert those strings into a getitem call return str_to_get_item_call(call) elif isinstance(call, sql.elements.ColumnClause): return Lazy(call) elif not isinstance(call, Call): # verbs that use literal strings, need to convert them to a call # that returns a sqlalchemy "literal" object return Lazy(sql.literal(call)) # set up locals funcs dict f_dict1 = self.funcs['scalar'] f_dict2 = self.funcs['window' if window else 'aggregate'] funcs = {**f_dict1, **f_dict2} # determine dispatch class cls_name = 'window' if window else 'aggregate' dispatch_cls = self.dispatch_cls[cls_name] call_shaper = CallTreeLocal( funcs, call_sub_attr = self.call_sub_attr, dispatch_cls = dispatch_cls, result_cls = self.result_cls ) # raise informative error message if missing translation try: return call_shaper.enter(call) except FunctionLookupError as err: raise SqlFunctionLookupError.from_verb( verb_name or "Unknown", arg_name or "Unknown", err, short = True ) def track_call_windows(self, call, columns = None, window_cte = None): """Returns tuple of (new column expression, list of window exprs)""" columns = self.last_op.columns if columns is None else columns return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) def get_ordered_col_names(self): """Return columns from current select, with grouping columns first.""" ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] return list(self.group_by) + ungrouped @property def last_op(self): return self.ops[-1] if len(self.ops) else None @staticmethod def _create_table(tbl, columns = None, source = None): """Return a sqlalchemy.Table, autoloading column info if needed. Arguments: tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. columns: a tuple of column names for the table. Overrides source argument. source: a sqlalchemy engine, used to autoload columns. """ if isinstance(tbl, sqlalchemy.Table): return tbl if not isinstance(tbl, str): raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) if columns is None and source is None: raise ValueError("One of columns or source must be specified") schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() return sqlalchemy.Table( table_name, sqlalchemy.MetaData(bind = source), *columns, schema = schema, autoload_with = source if not columns else None ) def _get_preview(self): # need to make prev op a cte, so we don't override any previous limit new_sel = sql.select([self.last_op.alias()]).limit(5) tbl_small = self.append_op(new_sel) return collect(tbl_small) def __repr__(self): template = ( "# Source: lazy query\n" "# DB Conn: {}\n" "# Preview:\n{}\n" "# .. may have more rows" ) return template.format(repr(self.source.engine), repr(self._get_preview())) def _repr_html_(self): template = ( "<div>" "<pre>" "# Source: lazy query\n" "# DB Conn: {}\n" "# Preview:\n" "</pre>" "{}" "<p># .. may have more rows</p>" "</div>" ) data = self._get_preview() html_data = getattr(data, '_repr_html_', lambda: repr(data))() return template.format(self.source.engine, html_data) def _repr_grouped_df_html_(self): return "<div><p>(grouped data frame)</p>" + self._selected_obj._repr_html_() + "</div>" # Main Funcs # ============================================================================= # sql raw -------------- sql_raw = sql.literal_column # show query ----------- from sqlalchemy.ext.compiler import compiles, deregister from contextlib import contextmanager @contextmanager def use_simple_names(): get_col_name = lambda el, *args, **kwargs: str(el.element.name) try: yield compiles(sql.compiler._CompileLabel)(get_col_name) except: pass finally: deregister(sql.compiler._CompileLabel) @show_query.register(LazyTbl) def _show_query(tbl, simplify = False): query = tbl.last_op #if not simplify else compile_query = lambda: query.compile( dialect = tbl.source.dialect, compile_kwargs = {"literal_binds": True} ) if simplify: # try to strip table names and labels where uneccessary with use_simple_names(): print(compile_query()) else: # use a much more verbose query print(compile_query()) return tbl # collect ---------- @collect.register(LazyTbl) def _collect(__data, as_df = True): # TODO: maybe remove as_df options, always return dataframe # normally can just pass the sql objects to execute, but for some reason # psycopg2 completes about incomplete template. # see https://stackoverflow.com/a/47193568/1144523 query = __data.last_op compiled = query.compile( dialect = __data.source.dialect, compile_kwargs = {"literal_binds": True} ) if as_df: return pd.read_sql(compiled, __data.source) return __data.source.execute(compiled).fetchall() @select.register(LazyTbl) def _select(__data, *args, **kwargs): # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object if kwargs: raise NotImplementedError( "Using kwargs in select not currently supported. " "Use _.newname == _.oldname instead" ) last_op = __data.last_op columns = {c.key: c for c in last_op.inner_columns} # same as for DataFrame colnames = Series(list(columns)) vl = VarList() evaluated = (arg(vl) if callable(arg) else arg for arg in args) od = var_select(colnames, *evaluated) col_list = [] for k,v in od.items(): col = columns[k] col_list.append(col if v is None else col.label(v)) return __data.append_op(last_op.with_only_columns(col_list)) @filter.register(LazyTbl) def _filter(__data, *args): # TODO: aggregate funcs # Note: currently always produces 2 additional select statements, # 1 for window/aggs, and 1 for the where clause sel = __data.last_op.alias() # original select win_sel = sql.select([sel], from_obj = sel) # first cte conds = [] windows = [] for ii, arg in enumerate(args): if isinstance(arg, Call): new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) #var_cols = new_call.op_vars(attr_calls = False) col_expr, win_cols = __data.track_call_windows( new_call, sel.columns, window_cte = win_sel ) conds.append(col_expr) else: conds.append(arg) bool_clause = sql.and_(*conds) # move non-window functions to refer to win_sel clause (not the innermost) --- win_alias = win_sel.alias() bool_clause = sql.util.ClauseAdapter(win_alias).traverse(bool_clause) # create second cte ---- orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] filt_sel = sql.select(orig_cols, from_obj = win_alias, whereclause = bool_clause) return __data.append_op(filt_sel) @mutate.register(LazyTbl) def _mutate(__data, **kwargs): # Cases # - work with group by # - window functions # TODO: verify it can follow a renaming select # track labeled columns in set sel = __data.last_op # evaluate each call for colname, func in kwargs.items(): # keep set of columns labeled (aliased) in this select statement # need to use inner cols, since sel.columns uses ColumnClause, not Label labs = set(k for k,v in lift_inner_cols(sel).items() if isinstance(v, sql.elements.Label)) new_call = __data.shape_call(func, verb_name = "Mutate", arg_name = colname) sel = _mutate_select(sel, colname, new_call, labs, __data) return __data.append_op(sel) def _mutate_select(sel, colname, func, labs, __data): """Return select statement containing a new column, func expr as colname. Note: since a func can refer to previous columns generated by mutate, this function handles whether to add a column to the existing select statement, or to use it as a subquery. """ replace_col = colname in sel.columns # Call objects let us check whether column expr used a derived column # e.g. SELECT a as b, b + 1 as c raises an error in SQL, so need subquery if not col_expr_requires_cte(func, sel, is_mutate = True): # New column may be able to modify existing select columns = lift_inner_cols(sel) else: # anything else requires a subquery cte = sel.alias(None) columns = cte.columns sel = sql.select([cte], from_obj = cte) # evaluate call expr on columns, making sure to use group vars new_col, windows = __data.track_call_windows(func, columns) # replacing an existing column, so strip it from select statement if replace_col: replaced = {**columns} replaced[colname] = new_col.label(colname) return sel.with_only_columns(list(replaced.values())) return sel.column(new_col.label(colname)) @transmute.register(LazyTbl) def _transmute(__data, **kwargs): # will use mutate, then select some cols f_mutate = mutate.registry[type(__data)] # transmute keeps grouping cols, and any defined in kwargs cols_to_keep = ordered_union(__data.group_by, kwargs) sel = f_mutate(__data, **kwargs).last_op columns = lift_inner_cols(sel) sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) return __data.append_op(sel_stripped) @arrange.register(LazyTbl) def _arrange(__data, *args): last_op = __data.last_op cols = lift_inner_cols(last_op) new_calls = [] for ii, expr in enumerate(args): if callable(expr): res = __data.shape_call( expr, window = False, verb_name = "Arrange", arg_name = ii ) else: res = expr new_calls.append(res) sort_cols = _create_order_by_clause(cols, *new_calls) order_by = __data.order_by + tuple(new_calls) return __data.append_op(last_op.order_by(*sort_cols), order_by = order_by) # TODO: consolidate / pull expr handling funcs into own file? def _create_order_by_clause(columns, *args): sort_cols = [] for arg in args: # simple named column if isinstance(arg, str): sort_cols.append(columns[arg]) # an expression elif callable(arg): #f, asc = _call_strip_ascending(arg) #col_op = f(cols) if asc else f(cols).desc() col_op = arg(columns) sort_cols.append(col_op) else: raise NotImplementedError("Must be string or callable") return sort_cols @count.register(LazyTbl) def _count(__data, *args, sort = False, wt = None, **kwargs): # TODO: if already col named n, use name nn, etc.. get logic from tidy.py if wt is not None: raise NotImplementedError("TODO") res_name = "n" # similar to filter verb, we need two select statements, # an inner one for derived cols, and outer to group by them # inner select ---- # holds any mutation style columns arg_names = [] for arg in args: name = simple_varname(arg) if name is None: raise NotImplementedError( "Count positional arguments must be single column name. " "Use a named argument to count using complex expressions." ) arg_names.append(name) tbl_inner = mutate(__data, **kwargs) sel_inner = tbl_inner.last_op group_cols = arg_names + list(kwargs) # outer select ---- # holds selected columns and tally (n) sel_inner_cte = sel_inner.alias() inner_cols = sel_inner_cte.columns sel_outer = sql.select(from_obj = sel_inner_cte) # apply any group vars from a group_by verb call first prev_group_cols = [inner_cols[k] for k in tbl_inner.group_by] if prev_group_cols: sel_outer.append_group_by(*prev_group_cols) sel_outer.append_column(*prev_group_cols) # now any defined in the count verb call for k in group_cols: sel_outer.append_group_by(inner_cols[k]) sel_outer.append_column(inner_cols[k]) count_col = sql.functions.count().label(res_name) sel_outer.append_column(count_col) # count is like summarize, so removes order_by return tbl_inner.append_op( sel_outer.order_by(count_col.desc()), order_by = tuple() ) @summarize.register(LazyTbl) def _summarize(__data, **kwargs): # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query # what if windowed mutate or filter has been done? # - filter is fine, since it uses a CTE # - need to detect any window functions... sel = __data.last_op._clone() new_calls = {} for k, expr in kwargs.items(): new_calls[k] = __data.shape_call( expr, window = False, verb_name = "Summarize", arg_name = k ) needs_cte = [col_expr_requires_cte(call, sel) for call in new_calls.values()] # create select statement ---- if any(needs_cte): # need a cte, due to alias cols or existing group by # current select stmt has group by clause, so need to make it subquery cte = sel.alias() columns = cte.columns sel = sql.select(from_obj = cte) else: # otherwise, can alter the existing select statement columns = lift_inner_cols(sel) old_froms = sel.froms sel = sel.with_only_columns([]) sel.append_from(*old_froms) # add group by columns ---- group_cols = [columns[k] for k in __data.group_by] sel.append_group_by(*group_cols) for col in group_cols: sel.append_column(col) # add each aggregate column ---- # TODO: can't do summarize(b = mean(a), c = b + mean(a)) # since difficult for c to refer to agg and unagg cols in SQL for k, expr in new_calls.items(): missing_cols = get_missing_columns(expr, columns) if missing_cols: raise NotImplementedError( "Summarize cannot find the following columns: %s. " "Note that it cannot refer to variables defined earlier in the " "same summarize call." % missing_cols ) col = expr(columns).label(k) sel.append_column(col) new_data = __data.append_op(sel, group_by = tuple(), order_by = tuple()) return new_data @group_by.register(LazyTbl) def _group_by(__data, *args, add = False, **kwargs): if kwargs: data = mutate(__data, **kwargs) else: data = __data cols = data.last_op.columns # put kwarg grouping vars last, so similar order to function call groups = tuple(simple_varname(arg) for arg in args) + tuple(kwargs) if None in groups: raise NotImplementedError("Complex expressions not supported in sql group_by") unmatched = set(groups) - set(cols.keys()) if unmatched: raise KeyError("group_by specifies columns missing from table: %s" %unmatched) if add: groups = ordered_union(data.group_by, groups) return data.copy(group_by = groups) @ungroup.register(LazyTbl) def _ungroup(__data): return __data.copy(group_by = tuple()) @case_when.register(sql.base.ImmutableColumnCollection) def _case_when(__data, cases): # TODO: will need listener to enter case statements, to handle when they use windows if isinstance(cases, Call): cases = cases(__data) whens = [] case_items = list(cases.items()) n_items = len(case_items) else_val = None for ii, (expr, val) in enumerate(case_items): # handle where val is a column expr if callable(val): val = val(__data) # handle when expressions if ii+1 == n_items and expr is True: else_val = val elif callable(expr): whens.append((expr(__data), val)) else: whens.append((expr, val)) return sql.case(whens, else_ = else_val) # Join ------------------------------------------------------------------------ from collections.abc import Mapping def _joined_cols(left_cols, right_cols, on_keys, full = False): """Return labeled columns, according to selection rules for joins. Rules: 1. For join keys, keep left table's column 2. When keys have the same labels, add suffix """ # TODO: remove sets, so uses stable ordering # when left and right cols have same name, suffix with _x / _y keep_right = set(right_cols.keys()) - set(on_keys.values()) shared_labs = set(left_cols.keys()).intersection(keep_right) right_cols_no_keys = {k: right_cols[k] for k in keep_right} # for an outer join, have key columns coalesce values if full: left_cols = {**left_cols} for lk, rk in on_keys.items(): col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) left_cols[lk] = col.label(lk) # create labels ---- l_labs = _relabeled_cols(left_cols, shared_labs, "_x") r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, "_y") return l_labs + r_labs def _relabeled_cols(columns, keys, suffix): # add a suffix to all columns with names in keys cols = [] for k, v in columns.items(): new_col = v.label(k + str(suffix)) if k in keys else v cols.append(new_col) return cols @join.register(LazyTbl) def _join(left, right, on = None, *args, how = "inner", sql_on = None): _raise_if_args(args) # Needs to be on the table, not the select left_sel = left.last_op.alias() right_sel = right.last_op.alias() # handle arguments ---- on = _validate_join_arg_on(on, sql_on) how = _validate_join_arg_how(how) # for equality join used to combine keys into single column consolidate_keys = on if sql_on is None else {} if how == "right": # switch joins, since sqlalchemy doesn't have right join arg # see https://stackoverflow.com/q/11400307/1144523 left_sel, right_sel = right_sel, left_sel # create join conditions ---- bool_clause = _create_join_conds(left_sel, right_sel, on) # create join ---- join = left_sel.join( right_sel, onclause = bool_clause, isouter = how != "inner", full = how == "full" ) # note, shared_keys assumes on is a mapping... # TODO: shared_keys appears to be for when on is not specified, but was unused #shared_keys = [k for k,v in on.items() if k == v] labeled_cols = _joined_cols( left_sel.columns, right_sel.columns, on_keys = consolidate_keys, full = how == "full" ) sel = sql.select(labeled_cols, from_obj = join) return left.append_op(sel) @semi_join.register(LazyTbl) def _semi_join(left, right = None, on = None, *args, sql_on = None): _raise_if_args(args) left_sel = left.last_op.alias() right_sel = right.last_op.alias() # handle arguments ---- on = _validate_join_arg_on(on, sql_on) # create join conditions ---- bool_clause = _create_join_conds(left_sel, right_sel, on) # create inner join ---- exists_clause = sql.select( [sql.literal(1)], from_obj = right_sel, whereclause = bool_clause ) # only keep left hand select's columns ---- sel = sql.select( left_sel.columns, from_obj = left_sel, whereclause = sql.exists(exists_clause) ) return left.append_op(sel) @anti_join.register(LazyTbl) def _anti_join(left, right = None, on = None, *args, sql_on = None): _raise_if_args(args) left_sel = left.last_op.alias() right_sel = right.last_op.alias() # handle arguments ---- on = _validate_join_arg_on(on, sql_on) # create join conditions ---- bool_clause = _create_join_conds(left_sel, right_sel, on) # create inner join ---- not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) sel = sql.select(left_sel.columns, from_obj = left_sel).where(not_exists) return left.append_op(sel) def _raise_if_args(args): if len(args): raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") def _validate_join_arg_on(on, sql_on = None): # handle sql on case if sql_on is not None: if on is not None: raise ValueError("Cannot specify both on and sql_on") return sql_on # handle general cases if on is None: raise NotImplementedError("on arg currently cannot be None (default) for SQL") elif isinstance(on, str): on = {on: on} elif isinstance(on, (list, tuple)): on = dict(zip(on, on)) if not isinstance(on, Mapping): raise TypeError("on must be a Mapping (e.g. dict)") return on def _validate_join_arg_how(how): how_options = ("inner", "left", "right", "full") if how not in how_options: raise ValueError("how argument needs to be one of %s" %how_options) return how def _create_join_conds(left_sel, right_sel, on): left_cols = left_sel.columns #lift_inner_cols(left_sel) right_cols = right_sel.columns #lift_inner_cols(right_sel) if callable(on): # callable, like with sql_on arg conds = [on(left_cols, right_cols)] else: # dict-like of form {left: right} conds = [] for l, r in on.items(): col_expr = left_cols[l] == right_cols[r] conds.append(col_expr) return sql.and_(*conds) # Head ------------------------------------------------------------------------ @head.register(LazyTbl) def _head(__data, n = 5): sel = __data.last_op return __data.append_op(sel.limit(n)) # Rename ---------------------------------------------------------------------- @rename.register(LazyTbl) def _rename(__data, **kwargs): sel = __data.last_op columns = lift_inner_cols(sel) # old_keys uses dict as ordered set old_to_new = {simple_varname(v):k for k,v in kwargs.items()} if None in old_to_new: raise KeyError("positional arguments must be simple column, " "e.g. _.colname or _['colname']" ) labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] new_sel = sel.with_only_columns(labs) return __data.append_op(new_sel) # Distinct -------------------------------------------------------------------- @distinct.register(LazyTbl) def _distinct(__data, *args, _keep_all = False, **kwargs): if (args or kwargs) and _keep_all: raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op # TODO: this is copied from the df distinct version # cols dict below is used as ordered set cols = {simple_varname(x): True for x in args} cols.update(kwargs) if None in cols: raise KeyError("positional arguments must be simple column, " "e.g. _.colname or _['colname']" ) # use all columns by default if not cols: cols = list(inner_sel.columns.keys()) if not len(inner_sel._order_by_clause): # select distinct has to include any columns in the order by clause, # so can only safely modify existing statement when there's no order by sel_cols = lift_inner_cols(inner_sel) distinct_cols = [sel_cols[k] for k in cols] sel = inner_sel.with_only_columns(distinct_cols).distinct() else: # fallback to cte cte = inner_sel.alias() distinct_cols = [cte.columns[k] for k in cols] sel = sql.select(distinct_cols, from_obj = cte).distinct() return __data.append_op(sel) # if_else --------------------------------------------------------------------- @if_else.register(sql.elements.ColumnElement) def _if_else(cond, true_vals, false_vals): whens = [(cond, true_vals)] return sql.case(whens, else_ = false_vals)