import contextlib
import functools
import operator
from typing import List, Optional

import pandas as pd
import sqlalchemy as sa
import sqlalchemy.sql as sql
from pkg_resources import parse_version
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect as PostgreSQLDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine.interfaces import Dialect as SQLAlchemyDialect

import ibis
import ibis.common.exceptions as com
import ibis.expr.analysis as L
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
import ibis.expr.window as W
import ibis.sql.compiler as comp
import ibis.sql.transforms as transforms
import ibis.util as util
from ibis.client import Database, Query, SQLClient
from ibis.sql.compiler import Dialect, Select, TableSetFormatter, Union

geospatial_supported = False
try:
    import geoalchemy2 as ga
    import geoalchemy2.shape as shape
    import geopandas

    geospatial_supported = True
except ImportError:
    pass


# TODO(cleanup)
_ibis_type_to_sqla = {
    dt.Null: sa.types.NullType,
    dt.Date: sa.Date,
    dt.Time: sa.Time,
    dt.Boolean: sa.Boolean,
    dt.Binary: sa.Binary,
    dt.String: sa.Text,
    dt.Decimal: sa.NUMERIC,
    # Mantissa-based
    dt.Float: sa.Float(precision=24),
    dt.Double: sa.Float(precision=53),
    dt.Int8: sa.SmallInteger,
    dt.Int16: sa.SmallInteger,
    dt.Int32: sa.Integer,
    dt.Int64: sa.BigInteger,
}


def _to_sqla_type(itype, type_map=None):
    if type_map is None:
        type_map = _ibis_type_to_sqla
    if isinstance(itype, dt.Decimal):
        return sa.types.NUMERIC(itype.precision, itype.scale)
    elif isinstance(itype, dt.Date):
        return sa.Date()
    elif isinstance(itype, dt.Timestamp):
        # SQLAlchemy DateTimes do not store the timezone, just whether the db
        # supports timezones.
        return sa.TIMESTAMP(bool(itype.timezone))
    elif isinstance(itype, dt.Array):
        ibis_type = itype.value_type
        if not isinstance(ibis_type, (dt.Primitive, dt.String)):
            raise TypeError(
                'Type {} is not a primitive type or string type'.format(
                    ibis_type
                )
            )
        return sa.ARRAY(_to_sqla_type(ibis_type, type_map=type_map))
    elif geospatial_supported and isinstance(itype, dt.GeoSpatial):
        if itype.geotype == 'geometry':
            return ga.Geometry
        elif itype.geotype == 'geography':
            return ga.Geography
        else:
            return ga.types._GISType
    else:
        return type_map[type(itype)]


@dt.dtype.register(SQLAlchemyDialect, sa.types.NullType)
def sa_null(_, satype, nullable=True):
    return dt.null


@dt.dtype.register(SQLAlchemyDialect, sa.types.Boolean)
def sa_boolean(_, satype, nullable=True):
    return dt.Boolean(nullable=nullable)


@dt.dtype.register(MySQLDialect, sa.dialects.mysql.NUMERIC)
def sa_mysql_numeric(_, satype, nullable=True):
    # https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
    return dt.Decimal(
        satype.precision or 10, satype.scale or 0, nullable=nullable
    )


@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.NUMERIC)
def sa_postgres_numeric(_, satype, nullable=True):
    # PostgreSQL allows any precision for numeric values if not specified,
    # up to the implementation limit. Here, default to the maximum value that
    # can be specified by the user. The scale defaults to zero.
    # https://www.postgresql.org/docs/10/datatype-numeric.html
    return dt.Decimal(
        satype.precision or 1000, satype.scale or 0, nullable=nullable
    )


@dt.dtype.register(SQLAlchemyDialect, sa.types.Numeric)
@dt.dtype.register(SQLiteDialect, sa.dialects.sqlite.NUMERIC)
def sa_numeric(_, satype, nullable=True):
    return dt.Decimal(satype.precision, satype.scale, nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.types.SmallInteger)
def sa_smallint(_, satype, nullable=True):
    return dt.Int16(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.types.Integer)
def sa_integer(_, satype, nullable=True):
    return dt.Int32(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.dialects.mysql.TINYINT)
def sa_mysql_tinyint(_, satype, nullable=True):
    return dt.Int8(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.types.BigInteger)
def sa_bigint(_, satype, nullable=True):
    return dt.Int64(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.types.Float)
def sa_float(_, satype, nullable=True):
    return dt.Float(nullable=nullable)


@dt.dtype.register(SQLiteDialect, sa.types.Float)
@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.DOUBLE_PRECISION)
def sa_double(_, satype, nullable=True):
    return dt.Double(nullable=nullable)


@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.UUID)
def sa_uuid(_, satype, nullable=True):
    return dt.UUID(nullable=nullable)


@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.JSON)
def sa_json(_, satype, nullable=True):
    return dt.JSON(nullable=nullable)


@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.JSONB)
def sa_jsonb(_, satype, nullable=True):
    return dt.JSONB(nullable=nullable)


if geospatial_supported:

    @dt.dtype.register(SQLAlchemyDialect, (ga.Geometry, ga.types._GISType))
    def ga_geometry(_, gatype, nullable=True):
        t = gatype.geometry_type
        if t == 'POINT':
            return dt.Point(nullable=nullable)
        if t == 'LINESTRING':
            return dt.LineString(nullable=nullable)
        if t == 'POLYGON':
            return dt.Polygon(nullable=nullable)
        if t == 'MULTILINESTRING':
            return dt.MultiLineString(nullable=nullable)
        if t == 'MULTIPOINT':
            return dt.MultiPoint(nullable=nullable)
        if t == 'MULTIPOLYGON':
            return dt.MultiPolygon(nullable=nullable)
        if t == 'GEOMETRY':
            return dt.Geometry(nullable=nullable)
        else:
            raise ValueError("Unrecognized geometry type: {}".format(t))


POSTGRES_FIELD_TO_IBIS_UNIT = {
    "YEAR": "Y",
    "MONTH": "M",
    "DAY": "D",
    "HOUR": "h",
    "MINUTE": "m",
    "SECOND": "s",
    "YEAR TO MONTH": "M",
    "DAY TO HOUR": "h",
    "DAY TO MINUTE": "m",
    "DAY TO SECOND": "s",
    "HOUR TO MINUTE": "m",
    "HOUR TO SECOND": "s",
    "MINUTE TO SECOND": "s",
}


@dt.dtype.register(PostgreSQLDialect, sa.dialects.postgresql.INTERVAL)
def sa_postgres_interval(_, satype, nullable=True):
    field = satype.fields.upper()
    unit = POSTGRES_FIELD_TO_IBIS_UNIT.get(field, None)
    if unit is None:
        raise ValueError(
            "Unknown PostgreSQL interval field {!r}".format(field)
        )
    elif unit in {"Y", "M"}:
        raise ValueError(
            "Variable length timedeltas are not yet supported with PostgreSQL"
        )
    return dt.Interval(unit=unit, nullable=nullable)


@dt.dtype.register(MySQLDialect, sa.dialects.mysql.DOUBLE)
def sa_mysql_double(_, satype, nullable=True):
    # TODO: handle asdecimal=True
    return dt.Double(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.types.String)
def sa_string(_, satype, nullable=True):
    return dt.String(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.types.Binary)
def sa_binary(_, satype, nullable=True):
    return dt.Binary(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.Time)
def sa_time(_, satype, nullable=True):
    return dt.Time(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.Date)
def sa_date(_, satype, nullable=True):
    return dt.Date(nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.DateTime)
def sa_datetime(_, satype, nullable=True, default_timezone='UTC'):
    timezone = default_timezone if satype.timezone else None
    return dt.Timestamp(timezone=timezone, nullable=nullable)


@dt.dtype.register(SQLAlchemyDialect, sa.ARRAY)
def sa_array(dialect, satype, nullable=True):
    dimensions = satype.dimensions
    if dimensions is not None and dimensions != 1:
        raise NotImplementedError('Nested array types not yet supported')

    value_dtype = dt.dtype(dialect, satype.item_type)
    return dt.Array(value_dtype, nullable=nullable)


@sch.infer.register(sa.Table)
def schema_from_table(table, schema=None):
    """Retrieve an ibis schema from a SQLAlchemy ``Table``.

    Parameters
    ----------
    table : sa.Table

    Returns
    -------
    schema : ibis.expr.datatypes.Schema
        An ibis schema corresponding to the types of the columns in `table`.
    """
    schema = schema if schema is not None else {}
    pairs = []
    for name, column in table.columns.items():
        if name in schema:
            dtype = dt.dtype(schema[name])
        else:
            dtype = dt.dtype(
                getattr(table.bind, 'dialect', SQLAlchemyDialect()),
                column.type,
                nullable=column.nullable,
            )
        pairs.append((name, dtype))
    return sch.schema(pairs)


def table_from_schema(name, meta, schema, database: Optional[str] = None):
    # Convert Ibis schema to SQLA table
    columns = []

    for colname, dtype in zip(schema.names, schema.types):
        satype = _to_sqla_type(dtype)
        column = sa.Column(colname, satype, nullable=dtype.nullable)
        columns.append(column)

    return sa.Table(name, meta, schema=database, *columns)


def _variance_reduction(func_name):
    suffix = {'sample': 'samp', 'pop': 'pop'}

    def variance_compiler(t, expr):
        arg, how, where = expr.op().args

        if arg.type().equals(dt.boolean):
            arg = arg.cast('int32')

        func = getattr(
            sa.func, '{}_{}'.format(func_name, suffix.get(how, 'samp'))
        )

        if where is not None:
            arg = where.ifelse(arg, None)
        return func(t.translate(arg))

    return variance_compiler


def infix_op(infix_sym):
    def formatter(t, expr):
        op = expr.op()
        left, right = op.args

        left_arg = t.translate(left)
        right_arg = t.translate(right)
        return left_arg.op(infix_sym)(right_arg)

    return formatter


def fixed_arity(sa_func, arity):
    if isinstance(sa_func, str):
        sa_func = getattr(sa.func, sa_func)

    def formatter(t, expr):
        if arity != len(expr.op().args):
            raise com.IbisError('incorrect number of args')

        return _varargs_call(sa_func, t, expr)

    return formatter


def varargs(sa_func):
    def formatter(t, expr):
        op = expr.op()
        trans_args = [t.translate(arg) for arg in op.arg]
        return sa_func(*trans_args)

    return formatter


def _varargs_call(sa_func, t, expr):
    op = expr.op()
    trans_args = [t.translate(arg) for arg in op.args]
    return sa_func(*trans_args)


def _table_column(t, expr):
    op = expr.op()
    ctx = t.context
    table = op.table

    sa_table = _get_sqla_table(ctx, table)
    out_expr = getattr(sa_table.c, op.name)

    # If the column does not originate from the table set in the current SELECT
    # context, we should format as a subquery
    if t.permit_subquery and ctx.is_foreign_expr(table):
        return sa.select([out_expr])

    return out_expr


def _get_sqla_table(ctx, table):
    if ctx.has_ref(table):
        ctx_level = ctx
        sa_table = ctx_level.get_table(table)
        while sa_table is None and ctx_level.parent is not ctx_level:
            ctx_level = ctx_level.parent
            sa_table = ctx_level.get_table(table)
    else:
        op = table.op()
        if isinstance(op, AlchemyTable):
            sa_table = op.sqla_table
        else:
            sa_table = ctx.get_compiled_expr(table)

    return sa_table


def _table_array_view(t, expr):
    ctx = t.context
    table = ctx.get_compiled_expr(expr.op().table)
    return table


def _exists_subquery(t, expr):
    op = expr.op()
    ctx = t.context

    filtered = op.foreign_table.filter(op.predicates).projection(
        [ir.literal(1).name(ir.unnamed)]
    )

    sub_ctx = ctx.subcontext()
    clause = to_sqlalchemy(filtered, sub_ctx, exists=True)

    if isinstance(op, transforms.NotExistsSubquery):
        clause = sa.not_(clause)

    return clause


def _cast(t, expr):
    op = expr.op()
    arg, target_type = op.args
    sa_arg = t.translate(arg)
    sa_type = t.get_sqla_type(target_type)

    if isinstance(arg, ir.CategoryValue) and target_type == 'int32':
        return sa_arg
    else:
        return sa.cast(sa_arg, sa_type)


def _contains(t, expr):
    op = expr.op()

    left, right = [t.translate(arg) for arg in op.args]

    return left.in_(right)


def _not_contains(t, expr):
    return sa.not_(_contains(t, expr))


def _reduction(sa_func):
    def formatter(t, expr):
        op = expr.op()
        *args, where = op.args

        return _reduction_format(t, sa_func, where, *args)

    return formatter


def _reduction_format(t, sa_func, where, arg, *args):
    if where is not None:
        arg = t.translate(where.ifelse(arg, ibis.NA))
    else:
        arg = t.translate(arg)

    return sa_func(arg, *map(t.translate, args))


def _literal(t, expr):
    dtype = expr.type()
    value = expr.op().value

    if isinstance(dtype, dt.Set):
        return list(map(sa.literal, value))

    return sa.literal(value)


def _value_list(t, expr):
    return [t.translate(x) for x in expr.op().values]


def _is_null(t, expr):
    arg = t.translate(expr.op().args[0])
    return arg.is_(sa.null())


def _not_null(t, expr):
    arg = t.translate(expr.op().args[0])
    return arg.isnot(sa.null())


def _round(t, expr):
    op = expr.op()
    arg, digits = op.args
    sa_arg = t.translate(arg)

    f = sa.func.round

    if digits is not None:
        sa_digits = t.translate(digits)
        return f(sa_arg, sa_digits)
    else:
        return f(sa_arg)


def _floor_divide(t, expr):
    left, right = map(t.translate, expr.op().args)
    return sa.func.floor(left / right)


def _count_distinct(t, expr):
    arg, where = expr.op().args

    if where is not None:
        sa_arg = t.translate(where.ifelse(arg, None))
    else:
        sa_arg = t.translate(arg)

    return sa.func.count(sa_arg.distinct())


def _simple_case(t, expr):
    op = expr.op()

    cases = [op.base == case for case in op.cases]
    return _translate_case(t, cases, op.results, op.default)


def _searched_case(t, expr):
    op = expr.op()
    return _translate_case(t, op.cases, op.results, op.default)


def _translate_case(t, cases, results, default):
    case_args = [t.translate(arg) for arg in cases]
    result_args = [t.translate(arg) for arg in results]

    whens = zip(case_args, result_args)
    default = t.translate(default)

    return sa.case(whens, else_=default)


def _negate(t, expr):
    op = expr.op()
    (arg,) = map(t.translate, op.args)
    return sa.not_(arg) if isinstance(expr, ir.BooleanValue) else -arg


def unary(sa_func):
    return fixed_arity(sa_func, 1)


def _string_like(t, expr):
    arg, pattern, escape = expr.op().args
    result = t.translate(arg).like(t.translate(pattern), escape=escape)
    return result


_cumulative_to_reduction = {
    ops.CumulativeSum: ops.Sum,
    ops.CumulativeMin: ops.Min,
    ops.CumulativeMax: ops.Max,
    ops.CumulativeMean: ops.Mean,
    ops.CumulativeAny: ops.Any,
    ops.CumulativeAll: ops.All,
}


def _cumulative_to_window(translator, expr, window):
    win = W.cumulative_window()
    win = win.group_by(window._group_by).order_by(window._order_by)

    op = expr.op()

    klass = _cumulative_to_reduction[type(op)]
    new_op = klass(*op.args)
    new_expr = expr._factory(new_op, name=expr._name)

    if type(new_op) in translator._rewrites:
        new_expr = translator._rewrites[type(new_op)](new_expr)

    return L.windowize_function(new_expr, win)


def _window(t, expr):
    op = expr.op()

    arg, window = op.args
    reduction = t.translate(arg)

    window_op = arg.op()

    _require_order_by = (
        ops.DenseRank,
        ops.MinRank,
        ops.NTile,
        ops.PercentRank,
    )

    if isinstance(window_op, ops.CumulativeOp):
        arg = _cumulative_to_window(t, arg, window)
        return t.translate(arg)

    if window.max_lookback is not None:
        raise NotImplementedError(
            'Rows with max lookback is not implemented '
            'for SQLAlchemy-based backends.'
        )

    # Some analytic functions need to have the expression of interest in
    # the ORDER BY part of the window clause
    if isinstance(window_op, _require_order_by) and not window._order_by:
        order_by = t.translate(window_op.args[0])
    else:
        order_by = list(map(t.translate, window._order_by))

    partition_by = list(map(t.translate, window._group_by))

    frame_clause_not_allowed = (
        ops.Lag,
        ops.Lead,
        ops.DenseRank,
        ops.MinRank,
        ops.NTile,
        ops.PercentRank,
        ops.RowNumber,
    )

    how = {'range': 'range_'}.get(window.how, window.how)
    preceding = window.preceding
    additional_params = (
        {}
        if isinstance(window_op, frame_clause_not_allowed)
        else {
            how: (
                -preceding if preceding is not None else preceding,
                window.following,
            )
        }
    )
    result = reduction.over(
        partition_by=partition_by, order_by=order_by, **additional_params
    )

    if isinstance(
        window_op, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)
    ):
        return result - 1
    else:
        return result


def _lag(t, expr):
    arg, offset, default = expr.op().args
    if default is not None:
        raise NotImplementedError()

    sa_arg = t.translate(arg)
    sa_offset = t.translate(offset) if offset is not None else 1
    return sa.func.lag(sa_arg, sa_offset)


def _lead(t, expr):
    arg, offset, default = expr.op().args
    if default is not None:
        raise NotImplementedError()
    sa_arg = t.translate(arg)
    sa_offset = t.translate(offset) if offset is not None else 1
    return sa.func.lead(sa_arg, sa_offset)


def _ntile(t, expr):
    op = expr.op()
    args = op.args
    arg, buckets = map(t.translate, args)
    return sa.func.ntile(buckets)


_operation_registry = {
    ops.And: fixed_arity(sql.and_, 2),
    ops.Or: fixed_arity(sql.or_, 2),
    ops.Not: unary(sa.not_),
    ops.Abs: unary(sa.func.abs),
    ops.Cast: _cast,
    ops.Coalesce: varargs(sa.func.coalesce),
    ops.NullIf: fixed_arity(sa.func.nullif, 2),
    ops.Contains: _contains,
    ops.NotContains: _not_contains,
    ops.Count: _reduction(sa.func.count),
    ops.Sum: _reduction(sa.func.sum),
    ops.Mean: _reduction(sa.func.avg),
    ops.Min: _reduction(sa.func.min),
    ops.Max: _reduction(sa.func.max),
    ops.CountDistinct: _count_distinct,
    ops.GroupConcat: _reduction(sa.func.group_concat),
    ops.Between: fixed_arity(sa.between, 3),
    ops.IsNull: _is_null,
    ops.NotNull: _not_null,
    ops.Negate: _negate,
    ops.Round: _round,
    ops.TypeOf: unary(sa.func.typeof),
    ops.Literal: _literal,
    ops.ValueList: _value_list,
    ops.NullLiteral: lambda *args: sa.null(),
    ops.SimpleCase: _simple_case,
    ops.SearchedCase: _searched_case,
    ops.TableColumn: _table_column,
    ops.TableArrayView: _table_array_view,
    transforms.ExistsSubquery: _exists_subquery,
    transforms.NotExistsSubquery: _exists_subquery,
    # miscellaneous varargs
    ops.Least: varargs(sa.func.least),
    ops.Greatest: varargs(sa.func.greatest),
    # string
    ops.LPad: fixed_arity(sa.func.lpad, 3),
    ops.RPad: fixed_arity(sa.func.rpad, 3),
    ops.Strip: unary(sa.func.trim),
    ops.LStrip: unary(sa.func.ltrim),
    ops.RStrip: unary(sa.func.rtrim),
    ops.Repeat: fixed_arity(sa.func.repeat, 2),
    ops.Reverse: unary(sa.func.reverse),
    ops.StrRight: fixed_arity(sa.func.right, 2),
    ops.Lowercase: unary(sa.func.lower),
    ops.Uppercase: unary(sa.func.upper),
    ops.StringAscii: unary(sa.func.ascii),
    ops.StringLength: unary(sa.func.length),
    ops.StringReplace: fixed_arity(sa.func.replace, 3),
    ops.StringSQLLike: _string_like,
    # math
    ops.Ln: unary(sa.func.ln),
    ops.Exp: unary(sa.func.exp),
    ops.Sign: unary(sa.func.sign),
    ops.Sqrt: unary(sa.func.sqrt),
    ops.Ceil: unary(sa.func.ceil),
    ops.Floor: unary(sa.func.floor),
    ops.Power: fixed_arity(sa.func.pow, 2),
    ops.FloorDivide: _floor_divide,
}


# TODO: unit tests for each of these
_binary_ops = {
    # Binary arithmetic
    ops.Add: operator.add,
    ops.Subtract: operator.sub,
    ops.Multiply: operator.mul,
    ops.Divide: operator.truediv,
    ops.Modulus: operator.mod,
    # Comparisons
    ops.Equals: operator.eq,
    ops.NotEquals: operator.ne,
    ops.Less: operator.lt,
    ops.LessEqual: operator.le,
    ops.Greater: operator.gt,
    ops.GreaterEqual: operator.ge,
    ops.IdenticalTo: lambda x, y: x.op('IS NOT DISTINCT FROM')(y),
    # Boolean comparisons
    # TODO
}


_window_functions = {
    ops.Lag: _lag,
    ops.Lead: _lead,
    ops.NTile: _ntile,
    ops.FirstValue: unary(sa.func.first_value),
    ops.LastValue: unary(sa.func.last_value),
    ops.RowNumber: fixed_arity(lambda: sa.func.row_number(), 0),
    ops.DenseRank: unary(lambda arg: sa.func.dense_rank()),
    ops.MinRank: unary(lambda arg: sa.func.rank()),
    ops.PercentRank: unary(lambda arg: sa.func.percent_rank()),
    ops.WindowOp: _window,
    ops.CumulativeOp: _window,
    ops.CumulativeMax: unary(sa.func.max),
    ops.CumulativeMin: unary(sa.func.min),
    ops.CumulativeSum: unary(sa.func.sum),
    ops.CumulativeMean: unary(sa.func.avg),
}

if geospatial_supported:
    _geospatial_functions = {
        ops.GeoArea: unary(sa.func.ST_Area),
        ops.GeoAsBinary: unary(sa.func.ST_AsBinary),
        ops.GeoAsEWKB: unary(sa.func.ST_AsEWKB),
        ops.GeoAsEWKT: unary(sa.func.ST_AsEWKT),
        ops.GeoAsText: unary(sa.func.ST_AsText),
        ops.GeoAzimuth: fixed_arity(sa.func.ST_Azimuth, 2),
        ops.GeoBuffer: fixed_arity(sa.func.ST_Buffer, 2),
        ops.GeoCentroid: unary(sa.func.ST_Centroid),
        ops.GeoContains: fixed_arity(sa.func.ST_Contains, 2),
        ops.GeoContainsProperly: fixed_arity(sa.func.ST_Contains, 2),
        ops.GeoCovers: fixed_arity(sa.func.ST_Covers, 2),
        ops.GeoCoveredBy: fixed_arity(sa.func.ST_CoveredBy, 2),
        ops.GeoCrosses: fixed_arity(sa.func.ST_Crosses, 2),
        ops.GeoDFullyWithin: fixed_arity(sa.func.ST_DFullyWithin, 3),
        ops.GeoDifference: fixed_arity(sa.func.ST_Difference, 2),
        ops.GeoDisjoint: fixed_arity(sa.func.ST_Disjoint, 2),
        ops.GeoDistance: fixed_arity(sa.func.ST_Distance, 2),
        ops.GeoDWithin: fixed_arity(sa.func.ST_DWithin, 3),
        ops.GeoEnvelope: unary(sa.func.ST_Envelope),
        ops.GeoEquals: fixed_arity(sa.func.ST_Equals, 2),
        ops.GeoGeometryN: fixed_arity(sa.func.ST_GeometryN, 2),
        ops.GeoGeometryType: unary(sa.func.ST_GeometryType),
        ops.GeoIntersection: fixed_arity(sa.func.ST_Intersection, 2),
        ops.GeoIntersects: fixed_arity(sa.func.ST_Intersects, 2),
        ops.GeoIsValid: unary(sa.func.ST_IsValid),
        ops.GeoLineLocatePoint: fixed_arity(sa.func.ST_LineLocatePoint, 2),
        ops.GeoLineMerge: unary(sa.func.ST_LineMerge),
        ops.GeoLineSubstring: fixed_arity(sa.func.ST_LineSubstring, 3),
        ops.GeoLength: unary(sa.func.ST_Length),
        ops.GeoNPoints: unary(sa.func.ST_NPoints),
        ops.GeoOrderingEquals: fixed_arity(sa.func.ST_OrderingEquals, 2),
        ops.GeoOverlaps: fixed_arity(sa.func.ST_Overlaps, 2),
        ops.GeoPerimeter: unary(sa.func.ST_Perimeter),
        ops.GeoSimplify: fixed_arity(sa.func.ST_Simplify, 3),
        ops.GeoSRID: unary(sa.func.ST_SRID),
        ops.GeoSetSRID: fixed_arity(sa.func.ST_SetSRID, 2),
        ops.GeoTouches: fixed_arity(sa.func.ST_Touches, 2),
        ops.GeoTransform: fixed_arity(sa.func.ST_Transform, 2),
        ops.GeoUnaryUnion: unary(sa.func.ST_Union),
        ops.GeoUnion: fixed_arity(sa.func.ST_Union, 2),
        ops.GeoWithin: fixed_arity(sa.func.ST_Within, 2),
        ops.GeoX: unary(sa.func.ST_X),
        ops.GeoY: unary(sa.func.ST_Y),
        # Missing Geospatial ops:
        #   ST_AsGML
        #   ST_AsGeoJSON
        #   ST_AsKML
        #   ST_AsRaster
        #   ST_AsSVG
        #   ST_AsTWKB
        #   ST_Distance_Sphere
        #   ST_Dump
        #   ST_DumpPoints
        #   ST_GeogFromText
        #   ST_GeomFromEWKB
        #   ST_GeomFromEWKT
        #   ST_GeomFromText
    }

    _operation_registry.update(_geospatial_functions)


for _k, _v in _binary_ops.items():
    _operation_registry[_k] = fixed_arity(_v, 2)


class AlchemySelectBuilder(comp.SelectBuilder):
    @property
    def _select_class(self):
        return AlchemySelect

    def _convert_group_by(self, exprs):
        return exprs


class AlchemyContext(comp.QueryContext):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._table_objects = {}

    def collapse(self, queries):
        if isinstance(queries, str):
            return queries

        if len(queries) > 1:
            raise NotImplementedError(
                'Only a single query is supported for SQLAlchemy backends'
            )
        return queries[0]

    def subcontext(self):
        return type(self)(
            dialect=self.dialect, parent=self, params=self.params
        )

    def _to_sql(self, expr, ctx):
        return to_sqlalchemy(expr, ctx)

    def _compile_subquery(self, expr):
        sub_ctx = self.subcontext()
        return self._to_sql(expr, sub_ctx)

    def has_table(self, expr, parent_contexts=False):
        key = self._get_table_key(expr)
        return self._key_in(
            key, '_table_objects', parent_contexts=parent_contexts
        )

    def set_table(self, expr, obj):
        key = self._get_table_key(expr)
        self._table_objects[key] = obj

    def get_table(self, expr):
        """
        Get the memoized SQLAlchemy expression object
        """
        return self._get_table_item('_table_objects', expr)


class AlchemyUnion(Union):
    def compile(self):
        def reduce_union(left, right, distincts=iter(self.distincts)):
            distinct = next(distincts)
            sa_func = sa.union if distinct else sa.union_all
            return sa_func(left, right)

        context = self.context
        selects = []

        for table in self.tables:
            table_set = context.get_compiled_expr(table)
            selects.append(table_set.cte().select())

        return functools.reduce(reduce_union, selects)


class AlchemyQueryBuilder(comp.QueryBuilder):

    select_builder = AlchemySelectBuilder
    union_class = AlchemyUnion


def to_sqlalchemy(expr, context, exists=False):
    ast = build_ast(expr, context)
    query = ast.queries[0]

    if exists:
        query.exists = exists

    return query.compile()


def build_ast(expr, context):
    builder = AlchemyQueryBuilder(expr, context)
    return builder.get_result()


class AlchemyDatabaseSchema(Database):
    def __init__(self, name, database):
        """

        Parameters
        ----------
        name : str
        database : AlchemyDatabase
        """
        self.name = name
        self.database = database
        self.client = database.client

    def __repr__(self):
        return "Schema({!r})".format(self.name)

    def drop(self, force=False):
        """
        Drop the schema

        Parameters
        ----------
        force : boolean, default False
          Drop any objects if they exist, and do not fail if the schema does
          not exist.
        """
        raise NotImplementedError(
            "Drop is not implemented yet for sqlalchemy schemas"
        )

    def table(self, name):
        """
        Return a table expression referencing a table in this schema

        Returns
        -------
        table : TableExpr
        """
        qualified_name = self._qualify(name)
        return self.database.table(qualified_name, self.name)

    def list_tables(self, like=None):
        return self.database.list_tables(
            schema=self.name, like=self._qualify_like(like)
        )


class AlchemyDatabase(Database):
    """

    Attributes
    ----------
    client : AlchemyClient

    """

    schema_class = AlchemyDatabaseSchema

    def table(self, name, schema=None):
        return self.client.table(name, schema=schema)

    def list_tables(self, like=None, schema=None):
        return self.client.list_tables(
            schema=schema, like=self._qualify_like(like), database=self.name
        )

    def schema(self, name):
        return self.schema_class(name, self)


class AlchemyTable(ops.DatabaseTable):
    def __init__(self, table, source, schema=None):
        schema = sch.infer(table, schema=schema)
        super().__init__(table.name, schema, source)
        self.sqla_table = table


class AlchemyExprTranslator(comp.ExprTranslator):

    _registry = _operation_registry
    _rewrites = comp.ExprTranslator._rewrites.copy()
    _type_map = _ibis_type_to_sqla

    context_class = AlchemyContext

    def name(self, translated, name, force=True):
        if hasattr(translated, 'label'):
            return translated.label(name)
        return translated

    def get_sqla_type(self, data_type):
        return _to_sqla_type(data_type, type_map=self._type_map)


rewrites = AlchemyExprTranslator.rewrites
compiles = AlchemyExprTranslator.compiles


class AlchemyQuery(Query):
    def _fetch(self, cursor):
        df = pd.DataFrame.from_records(
            cursor.proxy.fetchall(),
            columns=cursor.proxy.keys(),
            coerce_float=True,
        )
        schema = self.schema()
        return _maybe_to_geodataframe(schema.apply_to(df), schema)


class AlchemyDialect(Dialect):

    translator = AlchemyExprTranslator


def invalidates_reflection_cache(f):
    """Invalidate the SQLAlchemy reflection cache if `f` performs an operation
    that mutates database or table metadata such as ``CREATE TABLE``,
    ``DROP TABLE``, etc.

    Parameters
    ----------
    f : callable
        A method on :class:`ibis.sql.alchemy.AlchemyClient`
    """

    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        result = f(self, *args, **kwargs)

        # only invalidate the cache after we've succesfully called the wrapped
        # function
        self._reflection_cache_is_dirty = True
        return result

    return wrapped


class AlchemyClient(SQLClient):

    dialect = AlchemyDialect
    query_class = AlchemyQuery
    has_attachment = False

    def __init__(self, con: sa.engine.Engine) -> None:
        super().__init__()
        self.con = con
        self.meta = sa.MetaData(bind=con)
        self._inspector = sa.inspect(con)
        self._reflection_cache_is_dirty = False
        self._schemas = {}

    @property
    def inspector(self):
        if self._reflection_cache_is_dirty:
            self._inspector.info_cache.clear()
        return self._inspector

    @contextlib.contextmanager
    def begin(self):
        with self.con.begin() as bind:
            yield bind

    @invalidates_reflection_cache
    def create_table(self, name, expr=None, schema=None, database=None):
        if database == self.database_name:
            # avoid fully qualified name
            database = None

        if database is not None:
            raise NotImplementedError(
                'Creating tables from a different database is not yet '
                'implemented'
            )

        if expr is None and schema is None:
            raise ValueError('You must pass either an expression or a schema')

        if expr is not None and schema is not None:
            if not expr.schema().equals(ibis.schema(schema)):
                raise TypeError(
                    'Expression schema is not equal to passed schema. '
                    'Try passing the expression without the schema'
                )
        if schema is None:
            schema = expr.schema()

        self._schemas[self._fully_qualified_name(name, database)] = schema
        t = self._table_from_schema(
            name, schema, database=database or self.current_database
        )

        with self.begin() as bind:
            t.create(bind=bind)
            if expr is not None:
                bind.execute(
                    t.insert().from_select(list(expr.columns), expr.compile())
                )

    def _columns_from_schema(
        self, name: str, schema: sch.Schema
    ) -> List[sa.Column]:
        return [
            sa.Column(colname, _to_sqla_type(dtype), nullable=dtype.nullable)
            for colname, dtype in zip(schema.names, schema.types)
        ]

    def _table_from_schema(
        self, name: str, schema: sch.Schema, database: Optional[str] = None
    ) -> sa.Table:
        columns = self._columns_from_schema(name, schema)
        return sa.Table(name, self.meta, *columns)

    @invalidates_reflection_cache
    def drop_table(
        self,
        table_name: str,
        database: Optional[str] = None,
        force: bool = False,
    ) -> None:
        if database == self.database_name:
            # avoid fully qualified name
            database = None

        if database is not None:
            raise NotImplementedError(
                'Dropping tables from a different database is not yet '
                'implemented'
            )

        t = self._get_sqla_table(table_name, schema=database, autoload=False)
        t.drop(checkfirst=force)

        assert (
            not t.exists()
        ), 'Something went wrong during DROP of table {!r}'.format(t.name)

        self.meta.remove(t)

        qualified_name = self._fully_qualified_name(table_name, database)

        try:
            del self._schemas[qualified_name]
        except KeyError:  # schemas won't be cached if created with raw_sql
            pass

    def load_data(
        self,
        table_name: str,
        data: pd.DataFrame,
        database: str = None,
        if_exists: str = 'fail',
    ):
        """
        Load data from a dataframe to the backend.

        Parameters
        ----------
        table_name : string
        data: pandas.DataFrame
        database : string, optional
        if_exists : string, optional, default 'fail'
            The values available are: {‘fail’, ‘replace’, ‘append’}

        Raises
        ------
        NotImplementedError
            Loading data to a table from a different database is not
            yet implemented
        """
        if database == self.database_name:
            # avoid fully qualified name
            database = None

        if database is not None:
            raise NotImplementedError(
                'Loading data to a table from a different database is not '
                'yet implemented'
            )

        params = {}
        if self.has_attachment:
            # for database with attachment
            # see: https://github.com/ibis-project/ibis/issues/1930
            params['schema'] = self.database_name

        data.to_sql(
            table_name,
            con=self.con,
            index=False,
            if_exists=if_exists,
            **params,
        )

    def truncate_table(
        self, table_name: str, database: Optional[str] = None
    ) -> None:
        t = self._get_sqla_table(table_name, schema=database)
        t.delete().execute()

    def list_tables(
        self,
        like: Optional[str] = None,
        database: Optional[str] = None,
        schema: Optional[str] = None,
    ) -> List[str]:
        """List tables/views in the current or indicated database.

        Parameters
        ----------
        like
            Checks for this string contained in name
        database
            If not passed, uses the current database
        schema
            The schema namespace that tables should be listed from

        Returns
        -------
        List[str]

        """
        inspector = self.inspector
        # inspector returns a mutable version of its names, so make a copy.
        names = inspector.get_table_names(schema=schema).copy()
        names.extend(inspector.get_view_names(schema=schema))
        if like is not None:
            names = [x for x in names if like in x]
        return sorted(names)

    def _execute(self, query: str, results: bool = True):
        return AlchemyProxy(self.con.execute(query))

    @invalidates_reflection_cache
    def raw_sql(self, query: str, results: bool = False):
        return super().raw_sql(query, results=results)

    def _build_ast(self, expr, context):
        return build_ast(expr, context)

    def _get_sqla_table(self, name, schema=None, autoload=True):
        return sa.Table(name, self.meta, schema=schema, autoload=autoload)

    def _sqla_table_to_expr(self, table):
        node = self.table_class(table, self)
        return self.table_expr_class(node)

    @property
    def version(self):
        vstring = '.'.join(map(str, self.con.dialect.server_version_info))
        return parse_version(vstring)


class AlchemySelect(Select):
    def __init__(self, *args, **kwargs):
        self.exists = kwargs.pop('exists', False)
        super().__init__(*args, **kwargs)

    def compile(self):
        # Can't tell if this is a hack or not. Revisit later
        self.context.set_query(self)

        self._compile_subqueries()

        frag = self._compile_table_set()
        steps = [
            self._add_select,
            self._add_groupby,
            self._add_where,
            self._add_order_by,
            self._add_limit,
        ]

        for step in steps:
            frag = step(frag)

        return frag

    def _compile_subqueries(self):
        if not self.subqueries:
            return

        for expr in self.subqueries:
            result = self.context.get_compiled_expr(expr)
            alias = self.context.get_ref(expr)
            result = result.cte(alias)
            self.context.set_table(expr, result)

    def _compile_table_set(self):
        if self.table_set is not None:
            helper = _AlchemyTableSet(self, self.table_set)
            return helper.get_result()
        else:
            return None

    def _add_select(self, table_set):
        to_select = []

        has_select_star = False
        for expr in self.select_set:
            if isinstance(expr, ir.ValueExpr):
                arg = self._translate(expr, named=True)
            elif isinstance(expr, ir.TableExpr):
                if expr.equals(self.table_set):
                    cached_table = self.context.get_table(expr)
                    if cached_table is None:
                        # the select * case from materialized join
                        has_select_star = True
                        continue
                    else:
                        arg = table_set
                else:
                    arg = self.context.get_table(expr)
                    if arg is None:
                        raise ValueError(expr)

            to_select.append(arg)

        if has_select_star:
            if table_set is None:
                raise ValueError('table_set cannot be None here')

            clauses = [table_set] + to_select
        else:
            clauses = to_select

        if self.exists:
            result = sa.exists(clauses)
        else:
            result = sa.select(clauses)

        if self.distinct:
            result = result.distinct()

        if not has_select_star:
            if table_set is not None:
                return result.select_from(table_set)
            else:
                return result
        else:
            return result

    def _add_groupby(self, fragment):
        # GROUP BY and HAVING
        if not len(self.group_by):
            return fragment

        group_keys = [self._translate(arg) for arg in self.group_by]
        fragment = fragment.group_by(*group_keys)

        if len(self.having) > 0:
            having_args = [self._translate(arg) for arg in self.having]
            having_clause = functools.reduce(sql.and_, having_args)
            fragment = fragment.having(having_clause)

        return fragment

    def _add_where(self, fragment):
        if not len(self.where):
            return fragment

        args = [
            self._translate(pred, permit_subquery=True) for pred in self.where
        ]
        clause = functools.reduce(sql.and_, args)
        return fragment.where(clause)

    def _add_order_by(self, fragment):
        if not len(self.order_by):
            return fragment

        clauses = []
        for expr in self.order_by:
            key = expr.op()
            sort_expr = key.expr

            # here we have to determine if key.expr is in the select set (as it
            # will be in the case of order_by fused with an aggregation
            if _can_lower_sort_column(self.table_set, sort_expr):
                arg = sort_expr.get_name()
            else:
                arg = self._translate(sort_expr)

            if not key.ascending:
                arg = sa.desc(arg)

            clauses.append(arg)

        return fragment.order_by(*clauses)

    def _among_select_set(self, expr):
        for other in self.select_set:
            if expr.equals(other):
                return True
        return False

    def _add_limit(self, fragment):
        if self.limit is None:
            return fragment

        n, offset = self.limit['n'], self.limit['offset']
        fragment = fragment.limit(n)
        if offset is not None and offset != 0:
            fragment = fragment.offset(offset)

        return fragment

    @property
    def dialect(self):
        return self.context.dialect


class _AlchemyTableSet(TableSetFormatter):
    def get_result(self):
        # Got to unravel the join stack; the nesting order could be
        # arbitrary, so we do a depth first search and push the join tokens
        # and predicates onto a flat list, then format them
        op = self.expr.op()

        if isinstance(op, ops.Join):
            self._walk_join_tree(op)
        else:
            self.join_tables.append(self._format_table(self.expr))

        result = self.join_tables[0]
        for jtype, table, preds in zip(
            self.join_types, self.join_tables[1:], self.join_predicates
        ):
            if len(preds):
                sqla_preds = [self._translate(pred) for pred in preds]
                onclause = functools.reduce(sql.and_, sqla_preds)
            else:
                onclause = None

            if jtype in (ops.InnerJoin, ops.CrossJoin):
                result = result.join(table, onclause)
            elif jtype is ops.LeftJoin:
                result = result.join(table, onclause, isouter=True)
            elif jtype is ops.RightJoin:
                result = table.join(result, onclause, isouter=True)
            elif jtype is ops.OuterJoin:
                result = result.outerjoin(table, onclause, full=True)
            elif jtype is ops.LeftSemiJoin:
                result = sa.select([result]).where(
                    sa.exists(sa.select([1]).where(onclause))
                )
            elif jtype is ops.LeftAntiJoin:
                result = sa.select([result]).where(
                    ~(sa.exists(sa.select([1]).where(onclause)))
                )
            else:
                raise NotImplementedError(jtype)

        return result

    def _get_join_type(self, op):
        return type(op)

    def _format_table(self, expr):
        ctx = self.context
        ref_expr = expr
        op = ref_op = expr.op()

        if isinstance(op, ops.SelfReference):
            ref_expr = op.table
            ref_op = ref_expr.op()

        alias = ctx.get_ref(expr)

        if isinstance(ref_op, AlchemyTable):
            result = ref_op.sqla_table
        elif isinstance(ref_op, ops.UnboundTable):
            # use SQLAlchemy's TableClause and ColumnClause for unbound tables
            schema = ref_op.schema
            result = sa.table(
                ref_op.name if ref_op.name is not None else ctx.get_ref(expr),
                *(
                    sa.column(n, _to_sqla_type(t))
                    for n, t in zip(schema.names, schema.types)
                ),
            )
        else:
            # A subquery
            if ctx.is_extracted(ref_expr):
                # Was put elsewhere, e.g. WITH block, we just need to grab
                # its alias
                alias = ctx.get_ref(expr)

                # hack
                if isinstance(op, ops.SelfReference):
                    table = ctx.get_table(ref_expr)
                    self_ref = table.alias(alias)
                    ctx.set_table(expr, self_ref)
                    return self_ref
                else:
                    return ctx.get_table(expr)

            result = ctx.get_compiled_expr(expr)
            alias = ctx.get_ref(expr)

        result = result.alias(alias)
        ctx.set_table(expr, result)
        return result


def _can_lower_sort_column(table_set, expr):
    # TODO(wesm): This code is pending removal through cleaner internal
    # semantics

    # we can currently sort by just-appeared aggregate metrics, but the way
    # these are references in the expression DSL is as a SortBy (blocking
    # table operation) on an aggregation. There's a hack in _collect_SortBy
    # in the generic SQL compiler that "fuses" the sort with the
    # aggregation so they appear in same query. It's generally for
    # cosmetics and doesn't really affect query semantics.
    bases = ops.find_all_base_tables(expr)
    if len(bases) > 1:
        return False

    base = list(bases.values())[0]
    base_op = base.op()

    if isinstance(base_op, ops.Aggregation):
        return base_op.table.equals(table_set)
    elif isinstance(base_op, ops.Selection):
        return base.equals(table_set)
    else:
        return False


class AlchemyProxy:
    """
    Wraps a SQLAlchemy ResultProxy and ensures that .close() is called on
    garbage collection
    """

    def __init__(self, proxy):
        self.proxy = proxy

    def __del__(self):
        self._close_cursor()

    def _close_cursor(self):
        self.proxy.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self._close_cursor()

    def fetchall(self):
        return self.proxy.fetchall()


@rewrites(ops.NullIfZero)
def _nullifzero(expr):
    arg = expr.op().args[0]
    return (arg == 0).ifelse(ibis.NA, arg)


@compiles(ops.Divide)
def _true_divide(t, expr):
    op = expr.op()
    left, right = args = op.args

    if util.all_of(args, ir.IntegerValue):
        return t.translate(left.div(right.cast('double')))

    return fixed_arity(lambda x, y: x / y, 2)(t, expr)


@compiles(ops.SortKey)
def _sort_key(t, expr):
    # We need to define this for window functions that have an order by
    by, ascending = expr.op().args
    sort_direction = sa.asc if ascending else sa.desc
    return sort_direction(t.translate(by))


def _maybe_to_geodataframe(df, schema):
    """
    If the required libraries for geospatial support are installed, and if a
    geospatial column is present in the dataframe, convert it to a
    GeoDataFrame.
    """

    def to_shapely(row, name):
        return shape.to_shape(row[name]) if row[name] is not None else None

    if len(df) and geospatial_supported:
        geom_col = None
        for name, dtype in schema.items():
            if isinstance(dtype, dt.GeoSpatial):
                geom_col = geom_col or name
                df[name] = df.apply(lambda x: to_shapely(x, name), axis=1)
        if geom_col:
            df = geopandas.GeoDataFrame(df, geometry=geom_col)
    return df