import contextlib
import functools
import operator
from typing import List, Optional

import pandas as pd
import sqlalchemy as sa
import sqlalchemy.sql as sql
from pkg_resources import parse_version
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect as PostgreSQLDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine.interfaces import Dialect as SQLAlchemyDialect

import ibis
import ibis.common.exceptions as com
import ibis.expr.analysis as L
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
import ibis.expr.window as W
import ibis.sql.compiler as comp
import ibis.sql.transforms as transforms
import ibis.util as util
from ibis.client import Database, Query, SQLClient
from ibis.sql.compiler import Dialect, Select, TableSetFormatter, Union

geospatial_supported = False
try:
    import geoalchemy2 as ga
    import geoalchemy2.shape as shape
    import geopandas

    geospatial_supported = True
except ImportError:
    pass

# TODO(cleanup)
_ibis_type_to_sqla = {
    dt.Null: sa.types.NullType,
    dt.Date: sa.Date,
    dt.Time: sa.Time,
    dt.Boolean: sa.Boolean,
    dt.Binary: sa.Binary,
    dt.String: sa.Text,
    dt.Decimal: sa.NUMERIC,
    # Mantissa-based
    dt.Float: sa.Float(precision=24),
    dt.Double: sa.Float(precision=53),
    dt.Int8: sa.SmallInteger,
    dt.Int16: sa.SmallInteger,
    dt.Int32: sa.Integer,
    dt.Int64: sa.BigInteger,
}

def _to_sqla_type(itype, type_map=None):
    if type_map is None:
        type_map = _ibis_type_to_sqla
    if isinstance(itype, dt.Decimal):
        return sa.types.NUMERIC(itype.precision, itype.scale)
    elif isinstance(itype, dt.Date):
        return sa.Date()
    elif isinstance(itype, dt.Timestamp):
        # SQLAlchemy DateTimes do not store the timezone, just whether the db
        # supports timezones.
        return sa.TIMESTAMP(bool(itype.timezone))
    elif isinstance(itype, dt.Array):
        ibis_type = itype.value_type
        if not isinstance(ibis_type, (dt.Primitive, dt.String)):
            raise TypeError(
                'Type {} is not a primitive type or string type'.format(
                    ibis_type
                )
            )
        return sa.ARRAY(_to_sqla_type(ibis_type, type_map=type_map))
    elif geospatial_supported and isinstance(itype, dt.GeoSpatial):
        if itype.geotype == 'geometry':
            return ga.Geometry
        elif itype.geotype == 'geography':
            return ga.Geography
        else:
            return ga.types._GISType
    else:
        return type_map[type(itype)]
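
# Illustrative sketch (not part of the original module): how _to_sqla_type
# behaves for two common cases, assuming the default type map above.
#
#   _to_sqla_type(dt.Decimal(9, 2))    # -> sa.types.NUMERIC(9, 2)
#   _to_sqla_type(dt.Array(dt.int64))  # -> sa.ARRAY(sa.BigInteger)
#
# Arrays of non-primitive, non-string value types (e.g. arrays of arrays)
# raise TypeError.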

@dt.dtype.register(SQLAlchemyDialect, sa.types.NullType)
def sa_null(_, satype, nullable=True):
    return dt.null

@dt.dtype.register(SQLAlchemyDialect, sa.types.Boolean)
def sa_boolean(_, satype, nullable=True):
    return dt.Boolean(nullable=nullable)

@dt.dtype.register(MySQLDialect, sa.dialects.mysql.NUMERIC)
def sa_mysql_numeric(_, satype, nullable=True):
    return dt.Decimal(
        satype.precision or 10, satype.scale or 0, nullable=nullable
    )

@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.NUMERIC)
def sa_postgres_numeric(_, satype, nullable=True):
    # PostgreSQL allows any precision for numeric values if not specified,
    # up to the implementation limit. Here, default to the maximum value that
    # can be specified by the user. The scale defaults to zero.
    return dt.Decimal(
        satype.precision or 1000, satype.scale or 0, nullable=nullable
    )

@dt.dtype.register(SQLAlchemyDialect, sa.types.Numeric)
@dt.dtype.register(SQLiteDialect, sa.dialects.sqlite.NUMERIC)
def sa_numeric(_, satype, nullable=True):
    return dt.Decimal(satype.precision, satype.scale, nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.types.SmallInteger)
def sa_smallint(_, satype, nullable=True):
    return dt.Int16(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.types.Integer)
def sa_integer(_, satype, nullable=True):
    return dt.Int32(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.dialects.mysql.TINYINT)
def sa_mysql_tinyint(_, satype, nullable=True):
    return dt.Int8(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.types.BigInteger)
def sa_bigint(_, satype, nullable=True):
    return dt.Int64(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.types.Float)
def sa_float(_, satype, nullable=True):
    return dt.Float(nullable=nullable)

@dt.dtype.register(SQLiteDialect, sa.types.Float)
@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.DOUBLE_PRECISION)
def sa_double(_, satype, nullable=True):
    return dt.Double(nullable=nullable)

@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.UUID)
def sa_uuid(_, satype, nullable=True):
    return dt.UUID(nullable=nullable)

@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.JSON)
def sa_json(_, satype, nullable=True):
    return dt.JSON(nullable=nullable)

@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.JSONB)
def sa_jsonb(_, satype, nullable=True):
    return dt.JSONB(nullable=nullable)

if geospatial_supported:

    @dt.dtype.register(SQLAlchemyDialect, (ga.Geometry, ga.types._GISType))
    def ga_geometry(_, gatype, nullable=True):
        t = gatype.geometry_type
        if t == 'POINT':
            return dt.Point(nullable=nullable)
        if t == 'LINESTRING':
            return dt.LineString(nullable=nullable)
        if t == 'POLYGON':
            return dt.Polygon(nullable=nullable)
        if t == 'MULTILINESTRING':
            return dt.MultiLineString(nullable=nullable)
        if t == 'MULTIPOINT':
            return dt.MultiPoint(nullable=nullable)
        if t == 'MULTIPOLYGON':
            return dt.MultiPolygon(nullable=nullable)
        if t == 'GEOMETRY':
            return dt.Geometry(nullable=nullable)
        else:
            raise ValueError("Unrecognized geometry type: {}".format(t))

    "YEAR": "Y",
    "MONTH": "M",
    "DAY": "D",
    "HOUR": "h",
    "MINUTE": "m",
    "SECOND": "s",
    "YEAR TO MONTH": "M",
    "DAY TO HOUR": "h",
    "DAY TO MINUTE": "m",
    "DAY TO SECOND": "s",
    "HOUR TO MINUTE": "m",
    "HOUR TO SECOND": "s",
    "MINUTE TO SECOND": "s",

@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.INTERVAL)
def sa_postgres_interval(_, satype, nullable=True):
    field = satype.fields.upper()
    unit = POSTGRES_FIELD_TO_IBIS_UNIT.get(field, None)
    if unit is None:
        raise ValueError(
            "Unknown PostgreSQL interval field {!r}".format(field)
        )
    elif unit in {"Y", "M"}:
        raise ValueError(
            "Variable length timedeltas are not yet supported with PostgreSQL"
        )
    return dt.Interval(unit=unit, nullable=nullable)

@dt.dtype.register(MySQLDialect, sa.dialects.mysql.DOUBLE)
def sa_mysql_double(_, satype, nullable=True):
    # TODO: handle asdecimal=True
    return dt.Double(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.types.String)
def sa_string(_, satype, nullable=True):
    return dt.String(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.types.Binary)
def sa_binary(_, satype, nullable=True):
    return dt.Binary(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.Time)
def sa_time(_, satype, nullable=True):
    return dt.Time(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.Date)
def sa_date(_, satype, nullable=True):
    return dt.Date(nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.DateTime)
def sa_datetime(_, satype, nullable=True, default_timezone='UTC'):
    timezone = default_timezone if satype.timezone else None
    return dt.Timestamp(timezone=timezone, nullable=nullable)

@dt.dtype.register(SQLAlchemyDialect, sa.ARRAY)
def sa_array(dialect, satype, nullable=True):
    dimensions = satype.dimensions
    if dimensions is not None and dimensions != 1:
        raise NotImplementedError('Nested array types not yet supported')

    value_dtype = dt.dtype(dialect, satype.item_type)
    return dt.Array(value_dtype, nullable=nullable)

def schema_from_table(table, schema=None):
    """Retrieve an ibis schema from a SQLAlchemy ``Table``.

    Parameters
    ----------
    table : sa.Table

    Returns
    -------
    schema : ibis.expr.datatypes.Schema
        An ibis schema corresponding to the types of the columns in `table`.
    """
    schema = schema if schema is not None else {}
    pairs = []
    for name, column in table.columns.items():
        if name in schema:
            dtype = dt.dtype(schema[name])
        else:
            dtype = dt.dtype(
                getattr(table.bind, 'dialect', SQLAlchemyDialect()),
                column.type,
                nullable=column.nullable,
            )
        pairs.append((name, dtype))
    return sch.schema(pairs)
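
# Illustrative usage (hedged; the table below is hypothetical):
#
#   meta = sa.MetaData()
#   users = sa.Table(
#       'users', meta,
#       sa.Column('id', sa.BigInteger, nullable=False),
#       sa.Column('name', sa.Text),
#   )
#   schema_from_table(users)  # -> ibis schema {id: int64, name: string}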

def table_from_schema(name, meta, schema, database: Optional[str] = None):
    # Convert Ibis schema to SQLA table
    columns = []

    for colname, dtype in zip(schema.names, schema.types):
        satype = _to_sqla_type(dtype)
        column = sa.Column(colname, satype, nullable=dtype.nullable)
        columns.append(column)

    return sa.Table(name, meta, schema=database, *columns)

def _variance_reduction(func_name):
    suffix = {'sample': 'samp', 'pop': 'pop'}

    def variance_compiler(t, expr):
        arg, how, where = expr.op().args

        if arg.type().equals(dt.boolean):
            arg = arg.cast('int32')

        func = getattr(
            sa.func, '{}_{}'.format(func_name, suffix.get(how, 'samp'))
        )

        if where is not None:
            arg = where.ifelse(arg, None)
        return func(t.translate(arg))

    return variance_compiler
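
# Illustrative registration (hedged; mirrors what concrete backends do):
#
#   _operation_registry.update({
#       ops.Variance: _variance_reduction('var'),
#       ops.StandardDev: _variance_reduction('stddev'),
#   })
#
# so that t.a.var(how='pop') compiles to var_pop(a) and the default/sample
# variant to var_samp(a).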

def infix_op(infix_sym):
    def formatter(t, expr):
        op = expr.op()
        left, right = op.args

        left_arg = t.translate(left)
        right_arg = t.translate(right)
        return left_arg.op(infix_sym)(right_arg)

    return formatter

def fixed_arity(sa_func, arity):
    if isinstance(sa_func, str):
        sa_func = getattr(sa.func, sa_func)

    def formatter(t, expr):
        if arity != len(expr.op().args):
            raise com.IbisError('incorrect number of args')

        return _varargs_call(sa_func, t, expr)

    return formatter

def varargs(sa_func):
    def formatter(t, expr):
        op = expr.op()
        trans_args = [t.translate(arg) for arg in op.arg]
        return sa_func(*trans_args)

    return formatter

def _varargs_call(sa_func, t, expr):
    op = expr.op()
    trans_args = [t.translate(arg) for arg in op.args]
    return sa_func(*trans_args)
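
# These small combinators build translator entries with the signature
# (translator, expr) -> SQLAlchemy expression. Illustrative uses, mirroring
# entries in _operation_registry below (unary is defined further down):
#
#   unary(sa.func.abs)              # one-argument op -> abs(x)
#   fixed_arity(sa.func.nullif, 2)  # nullif(x, y), arity-checked
#   varargs(sa.func.coalesce)       # coalesce(*args) for variadic ops
#   infix_op('||')                  # x || y, via SQLAlchemy's .op()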

def _table_column(t, expr):
    op = expr.op()
    ctx = t.context
    table = op.table

    sa_table = _get_sqla_table(ctx, table)
    out_expr = getattr(sa_table.c, op.name)

    # If the column does not originate from the table set in the current SELECT
    # context, we should format as a subquery
    if t.permit_subquery and ctx.is_foreign_expr(table):
        return sa.select([out_expr])

    return out_expr

def _get_sqla_table(ctx, table):
    if ctx.has_ref(table):
        ctx_level = ctx
        sa_table = ctx_level.get_table(table)
        while sa_table is None and ctx_level.parent is not ctx_level:
            ctx_level = ctx_level.parent
            sa_table = ctx_level.get_table(table)
    else:
        op = table.op()
        if isinstance(op, AlchemyTable):
            sa_table = op.sqla_table
        else:
            sa_table = ctx.get_compiled_expr(table)

    return sa_table

def _table_array_view(t, expr):
    ctx = t.context
    table = ctx.get_compiled_expr(expr.op().table)
    return table

def _exists_subquery(t, expr):
    op = expr.op()
    ctx = t.context

    filtered = op.foreign_table.filter(op.predicates).projection(
        [ir.literal(1).name(ir.unnamed)]
    )
    sub_ctx = ctx.subcontext()
    clause = to_sqlalchemy(filtered, sub_ctx, exists=True)

    if isinstance(op, transforms.NotExistsSubquery):
        clause = sa.not_(clause)

    return clause

def _cast(t, expr):
    op = expr.op()
    arg, target_type = op.args
    sa_arg = t.translate(arg)
    sa_type = t.get_sqla_type(target_type)

    if isinstance(arg, ir.CategoryValue) and target_type == 'int32':
        return sa_arg
    else:
        return sa.cast(sa_arg, sa_type)

def _contains(t, expr):
    op = expr.op()

    left, right = [t.translate(arg) for arg in op.args]

    return left.in_(right)

def _not_contains(t, expr):
    return sa.not_(_contains(t, expr))

def _reduction(sa_func):
    def formatter(t, expr):
        op = expr.op()
        *args, where = op.args

        return _reduction_format(t, sa_func, where, *args)

    return formatter

def _reduction_format(t, sa_func, where, arg, *args):
    if where is not None:
        arg = t.translate(where.ifelse(arg, ibis.NA))
    else:
        arg = t.translate(arg)

    return sa_func(arg, *map(t.translate, args))

def _literal(t, expr):
    dtype = expr.type()
    value = expr.op().value

    if isinstance(dtype, dt.Set):
        return list(map(sa.literal, value))

    return sa.literal(value)

def _value_list(t, expr):
    return [t.translate(x) for x in expr.op().values]

def _is_null(t, expr):
    arg = t.translate(expr.op().args[0])
    return arg.is_(sa.null())

def _not_null(t, expr):
    arg = t.translate(expr.op().args[0])
    return arg.isnot(sa.null())

def _round(t, expr):
    op = expr.op()
    arg, digits = op.args
    sa_arg = t.translate(arg)

    f = sa.func.round

    if digits is not None:
        sa_digits = t.translate(digits)
        return f(sa_arg, sa_digits)
    else:
        return f(sa_arg)

def _floor_divide(t, expr):
    left, right = map(t.translate, expr.op().args)
    return sa.func.floor(left / right)

def _count_distinct(t, expr):
    arg, where = expr.op().args

    if where is not None:
        sa_arg = t.translate(where.ifelse(arg, None))
    else:
        sa_arg = t.translate(arg)

    return sa.func.count(sa_arg.distinct())

def _simple_case(t, expr):
    op = expr.op()

    cases = [op.base == case for case in op.cases]
    return _translate_case(t, cases, op.results, op.default)

def _searched_case(t, expr):
    op = expr.op()
    return _translate_case(t, op.cases, op.results, op.default)

def _translate_case(t, cases, results, default):
    case_args = [t.translate(arg) for arg in cases]
    result_args = [t.translate(arg) for arg in results]

    whens = zip(case_args, result_args)
    default = t.translate(default)

    return sa.case(whens, else_=default)

def _negate(t, expr):
    op = expr.op()
    (arg,) = map(t.translate, op.args)
    return sa.not_(arg) if isinstance(expr, ir.BooleanValue) else -arg

def unary(sa_func):
    return fixed_arity(sa_func, 1)

def _string_like(t, expr):
    arg, pattern, escape = expr.op().args
    result = t.translate(arg).like(t.translate(pattern), escape=escape)
    return result

_cumulative_to_reduction = {
    ops.CumulativeSum: ops.Sum,
    ops.CumulativeMin: ops.Min,
    ops.CumulativeMax: ops.Max,
    ops.CumulativeMean: ops.Mean,
    ops.CumulativeAny: ops.Any,
    ops.CumulativeAll: ops.All,
}

def _cumulative_to_window(translator, expr, window):
    win = W.cumulative_window()
    win = win.group_by(window._group_by).order_by(window._order_by)

    op = expr.op()

    klass = _cumulative_to_reduction[type(op)]
    new_op = klass(*op.args)
    new_expr = expr._factory(new_op, name=expr._name)

    if type(new_op) in translator._rewrites:
        new_expr = translator._rewrites[type(new_op)](new_expr)

    return L.windowize_function(new_expr, win)
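
# Illustrative note: a cumulative op such as t.a.cumsum() is rewritten above
# into the corresponding reduction (ops.Sum) over a cumulative window, i.e.
# roughly t.a.sum().over(ibis.cumulative_window()), which then compiles
# through the ordinary window path in _window below.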

def _window(t, expr):
    op = expr.op()

    arg, window = op.args
    reduction = t.translate(arg)

    window_op = arg.op()

    _require_order_by = (
        ops.DenseRank,
        ops.MinRank,
        ops.NTile,
        ops.PercentRank,
    )

    if isinstance(window_op, ops.CumulativeOp):
        arg = _cumulative_to_window(t, arg, window)
        return t.translate(arg)

    if window.max_lookback is not None:
        raise NotImplementedError(
            'Rows with max lookback is not implemented '
            'for SQLAlchemy-based backends.'
        )

    # Some analytic functions need to have the expression of interest in
    # the ORDER BY part of the window clause
    if isinstance(window_op, _require_order_by) and not window._order_by:
        order_by = t.translate(window_op.args[0])
    else:
        order_by = list(map(t.translate, window._order_by))

    partition_by = list(map(t.translate, window._group_by))

    frame_clause_not_allowed = (
        ops.Lag,
        ops.Lead,
        ops.DenseRank,
        ops.MinRank,
        ops.NTile,
        ops.PercentRank,
        ops.RowNumber,
    )

    how = {'range': 'range_'}.get(window.how, window.how)
    preceding = window.preceding
    additional_params = (
        {}
        if isinstance(window_op, frame_clause_not_allowed)
        else {
            how: (
                -preceding if preceding is not None else preceding,
                window.following,
            )
        }
    )
    result = reduction.over(
        partition_by=partition_by, order_by=order_by, **additional_params
    )

    if isinstance(
        window_op, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)
    ):
        # SQL ranking functions are 1-indexed; ibis follows pandas and is
        # 0-indexed, so shift the result down by one
        return result - 1
    else:
        return result

def _lag(t, expr):
    arg, offset, default = expr.op().args
    if default is not None:
        raise NotImplementedError()

    sa_arg = t.translate(arg)
    sa_offset = t.translate(offset) if offset is not None else 1
    return sa.func.lag(sa_arg, sa_offset)

def _lead(t, expr):
    arg, offset, default = expr.op().args
    if default is not None:
        raise NotImplementedError()
    sa_arg = t.translate(arg)
    sa_offset = t.translate(offset) if offset is not None else 1
    return sa.func.lead(sa_arg, sa_offset)

def _ntile(t, expr):
    op = expr.op()
    args = op.args
    arg, buckets = map(t.translate, args)
    return sa.func.ntile(buckets)

_operation_registry = {
    ops.And: fixed_arity(sql.and_, 2),
    ops.Or: fixed_arity(sql.or_, 2),
    ops.Not: unary(sa.not_),
    ops.Abs: unary(sa.func.abs),
    ops.Cast: _cast,
    ops.Coalesce: varargs(sa.func.coalesce),
    ops.NullIf: fixed_arity(sa.func.nullif, 2),
    ops.Contains: _contains,
    ops.NotContains: _not_contains,
    ops.Count: _reduction(sa.func.count),
    ops.Sum: _reduction(sa.func.sum),
    ops.Mean: _reduction(sa.func.avg),
    ops.Min: _reduction(sa.func.min),
    ops.Max: _reduction(sa.func.max),
    ops.CountDistinct: _count_distinct,
    ops.GroupConcat: _reduction(sa.func.group_concat),
    ops.Between: fixed_arity(sa.between, 3),
    ops.IsNull: _is_null,
    ops.NotNull: _not_null,
    ops.Negate: _negate,
    ops.Round: _round,
    ops.TypeOf: unary(sa.func.typeof),
    ops.Literal: _literal,
    ops.ValueList: _value_list,
    ops.NullLiteral: lambda *args: sa.null(),
    ops.SimpleCase: _simple_case,
    ops.SearchedCase: _searched_case,
    ops.TableColumn: _table_column,
    ops.TableArrayView: _table_array_view,
    transforms.ExistsSubquery: _exists_subquery,
    transforms.NotExistsSubquery: _exists_subquery,
    # miscellaneous varargs
    ops.Least: varargs(sa.func.least),
    ops.Greatest: varargs(sa.func.greatest),
    # string
    ops.LPad: fixed_arity(sa.func.lpad, 3),
    ops.RPad: fixed_arity(sa.func.rpad, 3),
    ops.Strip: unary(sa.func.trim),
    ops.LStrip: unary(sa.func.ltrim),
    ops.RStrip: unary(sa.func.rtrim),
    ops.Repeat: fixed_arity(sa.func.repeat, 2),
    ops.Reverse: unary(sa.func.reverse),
    ops.StrRight: fixed_arity(sa.func.right, 2),
    ops.Lowercase: unary(sa.func.lower),
    ops.Uppercase: unary(sa.func.upper),
    ops.StringAscii: unary(sa.func.ascii),
    ops.StringLength: unary(sa.func.length),
    ops.StringReplace: fixed_arity(sa.func.replace, 3),
    ops.StringSQLLike: _string_like,
    # math
    ops.Ln: unary(sa.func.ln),
    ops.Exp: unary(sa.func.exp),
    ops.Sign: unary(sa.func.sign),
    ops.Sqrt: unary(sa.func.sqrt),
    ops.Ceil: unary(sa.func.ceil),
    ops.Floor: unary(sa.func.floor),
    ops.Power: fixed_arity(sa.func.pow, 2),
    ops.FloorDivide: _floor_divide,
}

# TODO: unit tests for each of these
_binary_ops = {
    # Binary arithmetic
    ops.Add: operator.add,
    ops.Subtract: operator.sub,
    ops.Multiply: operator.mul,
    ops.Divide: operator.truediv,
    ops.Modulus: operator.mod,
    # Comparisons
    ops.Equals: operator.eq,
    ops.NotEquals: operator.ne,
    ops.Less: operator.lt,
    ops.LessEqual: operator.le,
    ops.Greater: operator.gt,
    ops.GreaterEqual: operator.ge,
    ops.IdenticalTo: lambda x, y: x.op('IS NOT DISTINCT FROM')(y),
    # Boolean comparisons
    # TODO
}
_window_functions = {
    ops.Lag: _lag,
    ops.Lead: _lead,
    ops.NTile: _ntile,
    ops.FirstValue: unary(sa.func.first_value),
    ops.LastValue: unary(sa.func.last_value),
    ops.RowNumber: fixed_arity(lambda: sa.func.row_number(), 0),
    ops.DenseRank: unary(lambda arg: sa.func.dense_rank()),
    ops.MinRank: unary(lambda arg: sa.func.rank()),
    ops.PercentRank: unary(lambda arg: sa.func.percent_rank()),
    ops.WindowOp: _window,
    ops.CumulativeOp: _window,
    ops.CumulativeMax: unary(sa.func.max),
    ops.CumulativeMin: unary(sa.func.min),
    ops.CumulativeSum: unary(sa.func.sum),
    ops.CumulativeMean: unary(sa.func.avg),
}

if geospatial_supported:
    _geospatial_functions = {
        ops.GeoArea: unary(sa.func.ST_Area),
        ops.GeoAsBinary: unary(sa.func.ST_AsBinary),
        ops.GeoAsEWKB: unary(sa.func.ST_AsEWKB),
        ops.GeoAsEWKT: unary(sa.func.ST_AsEWKT),
        ops.GeoAsText: unary(sa.func.ST_AsText),
        ops.GeoAzimuth: fixed_arity(sa.func.ST_Azimuth, 2),
        ops.GeoBuffer: fixed_arity(sa.func.ST_Buffer, 2),
        ops.GeoCentroid: unary(sa.func.ST_Centroid),
        ops.GeoContains: fixed_arity(sa.func.ST_Contains, 2),
        ops.GeoContainsProperly: fixed_arity(sa.func.ST_Contains, 2),
        ops.GeoCovers: fixed_arity(sa.func.ST_Covers, 2),
        ops.GeoCoveredBy: fixed_arity(sa.func.ST_CoveredBy, 2),
        ops.GeoCrosses: fixed_arity(sa.func.ST_Crosses, 2),
        ops.GeoDFullyWithin: fixed_arity(sa.func.ST_DFullyWithin, 3),
        ops.GeoDifference: fixed_arity(sa.func.ST_Difference, 2),
        ops.GeoDisjoint: fixed_arity(sa.func.ST_Disjoint, 2),
        ops.GeoDistance: fixed_arity(sa.func.ST_Distance, 2),
        ops.GeoDWithin: fixed_arity(sa.func.ST_DWithin, 3),
        ops.GeoEnvelope: unary(sa.func.ST_Envelope),
        ops.GeoEquals: fixed_arity(sa.func.ST_Equals, 2),
        ops.GeoGeometryN: fixed_arity(sa.func.ST_GeometryN, 2),
        ops.GeoGeometryType: unary(sa.func.ST_GeometryType),
        ops.GeoIntersection: fixed_arity(sa.func.ST_Intersection, 2),
        ops.GeoIntersects: fixed_arity(sa.func.ST_Intersects, 2),
        ops.GeoIsValid: unary(sa.func.ST_IsValid),
        ops.GeoLineLocatePoint: fixed_arity(sa.func.ST_LineLocatePoint, 2),
        ops.GeoLineMerge: unary(sa.func.ST_LineMerge),
        ops.GeoLineSubstring: fixed_arity(sa.func.ST_LineSubstring, 3),
        ops.GeoLength: unary(sa.func.ST_Length),
        ops.GeoNPoints: unary(sa.func.ST_NPoints),
        ops.GeoOrderingEquals: fixed_arity(sa.func.ST_OrderingEquals, 2),
        ops.GeoOverlaps: fixed_arity(sa.func.ST_Overlaps, 2),
        ops.GeoPerimeter: unary(sa.func.ST_Perimeter),
        ops.GeoSimplify: fixed_arity(sa.func.ST_Simplify, 3),
        ops.GeoSRID: unary(sa.func.ST_SRID),
        ops.GeoSetSRID: fixed_arity(sa.func.ST_SetSRID, 2),
        ops.GeoTouches: fixed_arity(sa.func.ST_Touches, 2),
        ops.GeoTransform: fixed_arity(sa.func.ST_Transform, 2),
        ops.GeoUnaryUnion: unary(sa.func.ST_Union),
        ops.GeoUnion: fixed_arity(sa.func.ST_Union, 2),
        ops.GeoWithin: fixed_arity(sa.func.ST_Within, 2),
        ops.GeoX: unary(sa.func.ST_X),
        ops.GeoY: unary(sa.func.ST_Y),
        # Missing Geospatial ops:
        #   ST_AsGML
        #   ST_AsGeoJSON
        #   ST_AsKML
        #   ST_AsRaster
        #   ST_AsSVG
        #   ST_AsTWKB
        #   ST_Distance_Sphere
        #   ST_Dump
        #   ST_DumpPoints
        #   ST_GeogFromText
        #   ST_GeomFromEWKB
        #   ST_GeomFromEWKT
        #   ST_GeomFromText
    }

    _operation_registry.update(_geospatial_functions)

for _k, _v in _binary_ops.items():
    _operation_registry[_k] = fixed_arity(_v, 2)
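
# After the loop above, each binary op is translated by applying the Python
# operator to the translated operands. Illustratively, the registered
# formatter for ops.Add behaves like:
#
#   def _add(t, expr):
#       left, right = map(t.translate, expr.op().args)
#       return left + right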

class AlchemySelectBuilder(comp.SelectBuilder):
    def _select_class(self):
        return AlchemySelect

    def _convert_group_by(self, exprs):
        return exprs

class AlchemyContext(comp.QueryContext):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._table_objects = {}

    def collapse(self, queries):
        if isinstance(queries, str):
            return queries

        if len(queries) > 1:
            raise NotImplementedError(
                'Only a single query is supported for SQLAlchemy backends'
            )
        return queries[0]

    def subcontext(self):
        return type(self)(
            dialect=self.dialect, parent=self, params=self.params
        )

    def _to_sql(self, expr, ctx):
        return to_sqlalchemy(expr, ctx)

    def _compile_subquery(self, expr):
        sub_ctx = self.subcontext()
        return self._to_sql(expr, sub_ctx)

    def has_table(self, expr, parent_contexts=False):
        key = self._get_table_key(expr)
        return self._key_in(
            key, '_table_objects', parent_contexts=parent_contexts
        )

    def set_table(self, expr, obj):
        key = self._get_table_key(expr)
        self._table_objects[key] = obj

    def get_table(self, expr):
        """Get the memoized SQLAlchemy expression object."""
        return self._get_table_item('_table_objects', expr)

class AlchemyUnion(Union):
    def compile(self):
        def reduce_union(left, right, distincts=iter(self.distincts)):
            distinct = next(distincts)
            sa_func = sa.union if distinct else sa.union_all
            return sa_func(left, right)

        context = self.context
        selects = []

        for table in self.tables:
            table_set = context.get_compiled_expr(table)
            selects.append(table_set.cte().select())

        return functools.reduce(reduce_union, selects)

class AlchemyQueryBuilder(comp.QueryBuilder):

    select_builder = AlchemySelectBuilder
    union_class = AlchemyUnion

def to_sqlalchemy(expr, context, exists=False):
    ast = build_ast(expr, context)
    query = ast.queries[0]

    if exists:
        query.exists = exists

    return query.compile()

def build_ast(expr, context):
    builder = AlchemyQueryBuilder(expr, context)
    return builder.get_result()

class AlchemyDatabaseSchema(Database):
    def __init__(self, name, database):
        """

        Parameters
        ----------
        name : str
        database : AlchemyDatabase
        """
        self.name = name
        self.database = database
        self.client = database.client

    def __repr__(self):
        return "Schema({!r})".format(

    def drop(self, force=False):
        """Drop the schema.

        Parameters
        ----------
        force : boolean, default False
          Drop any objects if they exist, and do not fail if the schema does
          not exist.
        """
        raise NotImplementedError(
            "Drop is not implemented yet for sqlalchemy schemas"
        )

    def table(self, name):
        """Return a table expression referencing a table in this schema.

        Returns
        -------
        table : TableExpr
        """
        qualified_name = self._qualify(name)
        return self.database.table(qualified_name, self.name)

    def list_tables(self, like=None):
        return self.database.list_tables(
            schema=self.name, like=self._qualify_like(like)
        )

class AlchemyDatabase(Database):
    """

    Attributes
    ----------
    client : AlchemyClient
    """

    schema_class = AlchemyDatabaseSchema

    def table(self, name, schema=None):
        return self.client.table(name, schema=schema)

    def list_tables(self, like=None, schema=None):
        return self.client.list_tables(
            schema=schema, like=self._qualify_like(like),
        )

    def schema(self, name):
        return self.schema_class(name, self)

class AlchemyTable(ops.DatabaseTable):
    def __init__(self, table, source, schema=None):
        schema = sch.infer(table, schema=schema)
        super().__init__(table.name, schema, source)
        self.sqla_table = table

class AlchemyExprTranslator(comp.ExprTranslator):

    _registry = _operation_registry
    _rewrites = comp.ExprTranslator._rewrites.copy()
    _type_map = _ibis_type_to_sqla

    context_class = AlchemyContext

    def name(self, translated, name, force=True):
        if hasattr(translated, 'label'):
            return translated.label(name)
        return translated

    def get_sqla_type(self, data_type):
        return _to_sqla_type(data_type, type_map=self._type_map)

rewrites = AlchemyExprTranslator.rewrites
compiles = AlchemyExprTranslator.compiles
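
# Backends extend or override translations via these class-level decorators.
# Illustrative sketch (hedged; the operation and SQL function chosen here are
# hypothetical):
#
#   @compiles(ops.StringLength)
#   def _string_length(t, expr):
#       (arg,) = expr.op().args
#       return sa.func.char_length(t.translate(arg))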

class AlchemyQuery(Query):
    def _fetch(self, cursor):
        df = pd.DataFrame.from_records(
            cursor.proxy.fetchall(),
            columns=cursor.proxy.keys(),
            coerce_float=True,
        )
        schema = self.schema()
        return _maybe_to_geodataframe(schema.apply_to(df), schema)

class AlchemyDialect(Dialect):

    translator = AlchemyExprTranslator

def invalidates_reflection_cache(f):
    """Invalidate the SQLAlchemy reflection cache if `f` performs an operation
    that mutates database or table metadata such as ``CREATE TABLE``,
    ``DROP TABLE``, etc.

    Parameters
    ----------
    f : callable
        A method on :class:`ibis.sql.alchemy.AlchemyClient`
    """

    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        result = f(self, *args, **kwargs)

        # only invalidate the cache after we've successfully called the
        # wrapped
        # function
        self._reflection_cache_is_dirty = True
        return result

    return wrapped
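
# Usage sketch (hedged; rename_table is hypothetical):
#
#   class MyClient(AlchemyClient):
#       @invalidates_reflection_cache
#       def rename_table(self, old_name, new_name):
#           ...  # any DDL that mutates table metadata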

class AlchemyClient(SQLClient):

    dialect = AlchemyDialect
    query_class = AlchemyQuery
    has_attachment = False

    def __init__(self, con: sa.engine.Engine) -> None:
        self.con = con
        self.meta = sa.MetaData(bind=con)
        self._inspector = sa.inspect(con)
        self._reflection_cache_is_dirty = False
        self._schemas = {}

    @property
    def inspector(self):
        if self._reflection_cache_is_dirty:
            self._inspector.info_cache.clear()
        return self._inspector

    @contextlib.contextmanager
    def begin(self):
        with self.con.begin() as bind:
            yield bind

    @invalidates_reflection_cache
    def create_table(self, name, expr=None, schema=None, database=None):
        if database == self.database_name:
            # avoid fully qualified name
            database = None

        if database is not None:
            raise NotImplementedError(
                'Creating tables from a different database is not yet '
                'implemented'
            )

        if expr is None and schema is None:
            raise ValueError('You must pass either an expression or a schema')

        if expr is not None and schema is not None:
            if not expr.schema().equals(ibis.schema(schema)):
                raise TypeError(
                    'Expression schema is not equal to passed schema. '
                    'Try passing the expression without the schema'
                )
        if schema is None:
            schema = expr.schema()

        self._schemas[self._fully_qualified_name(name, database)] = schema
        t = self._table_from_schema(
            name, schema, database=database or self.current_database
        )

        with self.begin() as bind:
            t.create(bind=bind)
            if expr is not None:
                bind.execute(
                    t.insert().from_select(list(expr.columns), expr.compile())
                )

    def _columns_from_schema(
        self, name: str, schema: sch.Schema
    ) -> List[sa.Column]:
        return [
            sa.Column(colname, _to_sqla_type(dtype), nullable=dtype.nullable)
            for colname, dtype in zip(schema.names, schema.types)
        ]

    def _table_from_schema(
        self, name: str, schema: sch.Schema, database: Optional[str] = None
    ) -> sa.Table:
        columns = self._columns_from_schema(name, schema)
        return sa.Table(name, self.meta, *columns)

    @invalidates_reflection_cache
    def drop_table(
        self,
        table_name: str,
        table_name: str,
        database: Optional[str] = None,
        force: bool = False,
    ) -> None:
        if database == self.database_name:
            # avoid fully qualified name
            database = None

        if database is not None:
            raise NotImplementedError(
                'Dropping tables from a different database is not yet '
                'implemented'
            )

        t = self._get_sqla_table(table_name, schema=database, autoload=False)
        t.drop(checkfirst=force)

        assert (
            not t.exists()
        ), 'Something went wrong during DROP of table {!r}'.format(t.name)

        self.meta.remove(t)

        qualified_name = self._fully_qualified_name(table_name, database)

        try:
            del self._schemas[qualified_name]
        except KeyError:  # schemas won't be cached if created with raw_sql
            pass

    def load_data(
        self,
        table_name: str,
        data: pd.DataFrame,
        database: str = None,
        if_exists: str = 'fail',
    ):
        """
        Load data from a dataframe to the backend.

        Parameters
        ----------
        table_name : string
        data: pandas.DataFrame
        database : string, optional
        if_exists : string, optional, default 'fail'
            The values available are: {‘fail’, ‘replace’, ‘append’}

        Raises
        ------
        NotImplementedError
            Loading data to a table from a different database is not
            yet implemented
        """
        if database == self.database_name:
            # avoid fully qualified name
            database = None

        if database is not None:
            raise NotImplementedError(
                'Loading data to a table from a different database is not '
                'yet implemented'
            )
        params = {}
        if self.has_attachment:
            # for database with attachment
            # see:
            params['schema'] = self.database_name

        data.to_sql(
            table_name,
            con=self.con,
            index=False,
            if_exists=if_exists,
            **params,
        )
    def truncate_table(
        self, table_name: str, database: Optional[str] = None
    ) -> None:
        t = self._get_sqla_table(table_name, schema=database)
        t.delete().execute()

    def list_tables(
        self,
        like: Optional[str] = None,
        like: Optional[str] = None,
        database: Optional[str] = None,
        schema: Optional[str] = None,
    ) -> List[str]:
        """List tables/views in the current or indicated database.

        Parameters
        ----------
        like
            Checks for this string contained in name
        database
            If not passed, uses the current database
        schema
            The schema namespace that tables should be listed from

        Returns
        -------
        List[str]
        """
        inspector = self.inspector
        # inspector returns a mutable version of its names, so make a copy.
        names = inspector.get_table_names(schema=schema).copy()
        if like is not None:
            names = [x for x in names if like in x]
        return sorted(names)

    def _execute(self, query: str, results: bool = True):
        return AlchemyProxy(self.con.execute(query))

    def raw_sql(self, query: str, results: bool = False):
        return super().raw_sql(query, results=results)

    def _build_ast(self, expr, context):
        return build_ast(expr, context)

    def _get_sqla_table(self, name, schema=None, autoload=True):
        return sa.Table(name, self.meta, schema=schema, autoload=autoload)

    def _sqla_table_to_expr(self, table):
        node = self.table_class(table, self)
        return self.table_expr_class(node)

    def version(self):
        vstring = '.'.join(map(str, self.con.dialect.server_version_info))
        return parse_version(vstring)

class AlchemySelect(Select):
    def __init__(self, *args, **kwargs):
        self.exists = kwargs.pop('exists', False)
        super().__init__(*args, **kwargs)

    def compile(self):
        # Can't tell if this is a hack or not. Revisit later
        self._compile_subqueries()

        frag = self._compile_table_set()
        steps = [
            self._add_select,
            self._add_groupby,
            self._add_where,
            self._add_order_by,
            self._add_limit,
        ]

        for step in steps:
            frag = step(frag)

        return frag

    def _compile_subqueries(self):
        if not self.subqueries:
            return

        for expr in self.subqueries:
            result = self.context.get_compiled_expr(expr)
            alias = self.context.get_ref(expr)
            result = result.cte(alias)
            self.context.set_table(expr, result)

    def _compile_table_set(self):
        if self.table_set is not None:
            helper = _AlchemyTableSet(self, self.table_set)
            return helper.get_result()
        else:
            return None

    def _add_select(self, table_set):
        to_select = []

        has_select_star = False
        for expr in self.select_set:
            if isinstance(expr, ir.ValueExpr):
                arg = self._translate(expr, named=True)
            elif isinstance(expr, ir.TableExpr):
                if expr.equals(self.table_set):
                    cached_table = self.context.get_table(expr)
                    if cached_table is None:
                        # the select * case from materialized join
                        has_select_star = True
                        continue
                    else:
                        arg = table_set
                else:
                    arg = self.context.get_table(expr)
                    if arg is None:
                        raise ValueError(expr)

            to_select.append(arg)

        if has_select_star:
            if table_set is None:
                raise ValueError('table_set cannot be None here')

            clauses = [table_set] + to_select
        else:
            clauses = to_select

        if self.exists:
            result = sa.exists(clauses)
        else:
            result = sa.select(clauses)

        if self.distinct:
            result = result.distinct()

        if not has_select_star:
            if table_set is not None:
                return result.select_from(table_set)
            else:
                return result
        else:
            return result

    def _add_groupby(self, fragment):
        # GROUP BY and HAVING
        if not len(self.group_by):
            return fragment

        group_keys = [self._translate(arg) for arg in self.group_by]
        fragment = fragment.group_by(*group_keys)

        if len(self.having) > 0:
            having_args = [self._translate(arg) for arg in self.having]
            having_clause = functools.reduce(sql.and_, having_args)
            fragment = fragment.having(having_clause)

        return fragment

    def _add_where(self, fragment):
        if not len(self.where):
            return fragment

        args = [
            self._translate(pred, permit_subquery=True) for pred in self.where
        ]
        clause = functools.reduce(sql.and_, args)
        return fragment.where(clause)

    def _add_order_by(self, fragment):
        if not len(self.order_by):
            return fragment

        clauses = []
        for expr in self.order_by:
            key = expr.op()
            sort_expr = key.expr

            # here we have to determine if key.expr is in the select set (as it
            # will be in the case of order_by fused with an aggregation)
            if _can_lower_sort_column(self.table_set, sort_expr):
                arg = sort_expr.get_name()
            else:
                arg = self._translate(sort_expr)

            if not key.ascending:
                arg = sa.desc(arg)

            clauses.append(arg)

        return fragment.order_by(*clauses)

    def _among_select_set(self, expr):
        for other in self.select_set:
            if expr.equals(other):
                return True
        return False

    def _add_limit(self, fragment):
        if self.limit is None:
            return fragment

        n, offset = self.limit['n'], self.limit['offset']
        fragment = fragment.limit(n)
        if offset is not None and offset != 0:
            fragment = fragment.offset(offset)

        return fragment

    @property
    def dialect(self):
        return self.context.dialect

class _AlchemyTableSet(TableSetFormatter):
    def get_result(self):
        # Got to unravel the join stack; the nesting order could be
        # arbitrary, so we do a depth first search and push the join tokens
        # and predicates onto a flat list, then format them
        op = self.expr.op()

        if isinstance(op, ops.Join):
            self._walk_join_tree(op)
        else:
            self.join_tables.append(self._format_table(self.expr))

        result = self.join_tables[0]
        for jtype, table, preds in zip(
            self.join_types, self.join_tables[1:], self.join_predicates
        ):
            if len(preds):
                sqla_preds = [self._translate(pred) for pred in preds]
                onclause = functools.reduce(sql.and_, sqla_preds)
            else:
                onclause = None

            if jtype in (ops.InnerJoin, ops.CrossJoin):
                result = result.join(table, onclause)
            elif jtype is ops.LeftJoin:
                result = result.join(table, onclause, isouter=True)
            elif jtype is ops.RightJoin:
                result = table.join(result, onclause, isouter=True)
            elif jtype is ops.OuterJoin:
                result = result.outerjoin(table, onclause, full=True)
            elif jtype is ops.LeftSemiJoin:
                result = sa.select([result]).where(
                    sa.exists(sa.select([1]).where(onclause))
                )
            elif jtype is ops.LeftAntiJoin:
                result = sa.select([result]).where(
                    ~sa.exists(sa.select([1]).where(onclause))
                )
            else:
                raise NotImplementedError(jtype)

        return result

    def _get_join_type(self, op):
        return type(op)

    def _format_table(self, expr):
        ctx = self.context
        ref_expr = expr
        op = ref_op = expr.op()

        if isinstance(op, ops.SelfReference):
            ref_expr = op.table
            ref_op = ref_expr.op()

        alias = ctx.get_ref(expr)

        if isinstance(ref_op, AlchemyTable):
            result = ref_op.sqla_table
        elif isinstance(ref_op, ops.UnboundTable):
            # use SQLAlchemy's TableClause and ColumnClause for unbound tables
            schema = ref_op.schema
            result = sa.table(
                ref_op.name if ref_op.name is not None else ctx.get_ref(expr),
                *(
                    sa.column(n, _to_sqla_type(t))
                    for n, t in zip(schema.names, schema.types)
                ),
            )
        else:
            # A subquery
            if ctx.is_extracted(ref_expr):
                # Was put elsewhere, e.g. WITH block, we just need to grab
                # its alias
                alias = ctx.get_ref(expr)

                # hack
                if isinstance(op, ops.SelfReference):
                    table = ctx.get_table(ref_expr)
                    self_ref = table.alias(alias)
                    ctx.set_table(expr, self_ref)
                    return self_ref
                else:
                    return ctx.get_table(expr)

            result = ctx.get_compiled_expr(expr)
            alias = ctx.get_ref(expr)

        result = result.alias(alias)
        ctx.set_table(expr, result)
        return result

def _can_lower_sort_column(table_set, expr):
    # TODO(wesm): This code is pending removal through cleaner internal
    # semantics

    # we can currently sort by just-appeared aggregate metrics, but the way
    # these are references in the expression DSL is as a SortBy (blocking
    # table operation) on an aggregation. There's a hack in _collect_SortBy
    # in the generic SQL compiler that "fuses" the sort with the
    # aggregation so they appear in same query. It's generally for
    # cosmetics and doesn't really affect query semantics.
    bases = ops.find_all_base_tables(expr)
    if len(bases) > 1:
        return False

    base = list(bases.values())[0]
    base_op = base.op()

    if isinstance(base_op, ops.Aggregation):
        return base_op.table.equals(table_set)
    elif isinstance(base_op, ops.Selection):
        return base.equals(table_set)
    else:
        return False

class AlchemyProxy:
    """
    Wraps a SQLAlchemy ResultProxy and ensures that .close() is called on
    garbage collection
    """

    def __init__(self, proxy):
        self.proxy = proxy

    def __del__(self):
        self._close_cursor()

    def _close_cursor(self):
        self.proxy.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self._close_cursor()

    def fetchall(self):
        return self.proxy.fetchall()

def _nullifzero(expr):
    arg = expr.op().args[0]
    return (arg == 0).ifelse(ibis.NA, arg)

def _true_divide(t, expr):
    op = expr.op()
    left, right = args = op.args

    if util.all_of(args, ir.IntegerValue):
        return t.translate(left.div(right.cast('double')))

    return fixed_arity(lambda x, y: x / y, 2)(t, expr)

def _sort_key(t, expr):
    # We need to define this for window functions that have an order by
    by, ascending = expr.op().args
    sort_direction = sa.asc if ascending else sa.desc
    return sort_direction(t.translate(by))

def _maybe_to_geodataframe(df, schema):
    """
    If the required libraries for geospatial support are installed, and if a
    geospatial column is present in the dataframe, convert it to a
    GeoDataFrame.
    """

    def to_shapely(row, name):
        return shape.to_shape(row[name]) if row[name] is not None else None

    if len(df) and geospatial_supported:
        geom_col = None
        for name, dtype in schema.items():
            if isinstance(dtype, dt.GeoSpatial):
                geom_col = geom_col or name
                df[name] = df.apply(lambda x: to_shapely(x, name), axis=1)
        if geom_col:
            df = geopandas.GeoDataFrame(df, geometry=geom_col)
    return df
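
# Illustrative behavior (hedged): for a schema containing a geometry column,
# _maybe_to_geodataframe converts each non-null element to a shapely geometry
# and wraps the frame as geopandas.GeoDataFrame(df, geometry=<first geospatial
# column>); when geopandas/geoalchemy2 are not importable, or the frame is
# empty, it returns df unchanged.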