"""Classes and supporting functions to manipulate ordering columns and extract
keyset markers from query results."""
from copy import copy
from warnings import warn

import sqlalchemy
from sqlalchemy import asc, column
from sqlalchemy.orm import Bundle, Mapper, class_mapper
from sqlalchemy.orm.attributes import QueryableAttribute
from sqlalchemy.sql.elements import _label_reference
from sqlalchemy.sql.expression import ClauseList, ColumnElement, Label
from sqlalchemy.sql.operators import (asc_op, desc_op, nullsfirst_op,
                                      nullslast_op)

# Element types that strip_labels() unwraps (via their ``.element``).
_LABELLED = (Label, _label_reference)
# Every ordering modifier that _remove_order_direction() strips out.
_ORDER_MODIFIERS = (asc_op, desc_op, nullsfirst_op, nullslast_op)
# Modifiers that make _remove_order_direction() emit a warning when found.
_UNSUPPORTED_ORDER_MODIFIERS = (nullsfirst_op, nullslast_op)
# Upper bound on the number of nested ``.element`` wrappers the helpers below
# will walk through before giving up.
_WRAPPING_DEPTH = 1000
# Message of the exception raised once _WRAPPING_DEPTH is exhausted.
_WRAPPING_OVERFLOW = ("Maximum element wrapping depth reached; there's "
                      "probably a circularity in sqlalchemy that "
                      "sqlakeyset doesn't know how to handle.")


def parse_clause(clause):
    """Parse an ORDER BY clause into a list of :class:`OC` instances.

    :param clause: A single ordering element, or a
        :class:`sqlalchemy.sql.expression.ClauseList` of them (lists may be
        nested).
    :returns: A list with one :class:`OC` per non-list element, in the order
        the elements are encountered."""
    def _flatten(cl):
        # Recursively expand (possibly nested) ClauseLists; anything that is
        # not a ClauseList is yielded unchanged.
        if isinstance(cl, ClauseList):
            for subclause in cl.clauses:
                yield from _flatten(subclause)
        else:
            yield cl
    return [OC(c) for c in _flatten(clause)]


def _warn_if_nullable(x):
    """Warn when `x` reports itself as nullable.

    `x` counts as nullable when ``x.nullable`` is truthy or, failing that,
    when ``x.property.columns[0].nullable`` is truthy. Objects that lack the
    attributes needed for this check are silently ignored.

    :param x: The ordering expression to inspect."""
    try:
        flagged = x.nullable or x.property.columns[0].nullable
        if not flagged:
            return
        warn(f"Ordering by nullable column {x} can cause rows to be "
             "incorrectly omitted from the results. "
             "See the sqlakeyset README for more details.")
    except (AttributeError, IndexError, KeyError):
        # Best-effort check only: nothing to report if `x` can't be probed.
        return

class OC:
    """A single ordering column: wraps one
    :class:`sqlalchemy.sql.expression.ColumnElement` taken from the ORDER BY
    clause of the query being paged.

    Instances expose ``uo`` (the element including its ordering direction),
    ``full_name`` (``str()`` of the element without ordering modifiers) and
    ``table_name`` / ``name`` (``full_name`` split at its first ``.``;
    ``table_name`` is ``None`` when there is no ``.``)."""

    def __init__(self, x):
        # Plain strings are treated as column names.
        if isinstance(x, str):
            x = column(x)
        # Elements without an explicit direction are made ascending.
        if _get_order_direction(x) is None:
            x = asc(x)
        self.uo = x
        _warn_if_nullable(self.comparable_value)

        full_name = str(self.element)
        self.full_name = full_name
        prefix, dot, remainder = full_name.partition('.')
        if dot:
            self.table_name = prefix
            self.name = remainder
        else:
            self.table_name = None
            self.name = full_name

    @property
    def quoted_full_name(self):
        """The first whitespace-delimited token of ``str(self)``."""
        tokens = str(self).split()
        return tokens[0]

    @property
    def element(self):
        """The ordering column/SQL expression without its ordering
        modifier."""
        return _remove_order_direction(self.uo)

    @property
    def comparable_value(self):
        """The ordering column/SQL expression with ordering modifiers and
        labels removed, i.e. in a form suitable for use inside a
        ``ROW(...) > ROW(...)`` comparison."""
        return strip_labels(self.element)

    @property
    def is_ascending(self):
        """``True`` if this column is ascending, ``False`` if descending."""
        direction = _get_order_direction(self.uo)
        if direction is None:
            raise ValueError  # pragma: no cover
        return direction == asc_op

    @property
    def reversed(self):
        """A new :class:`OC` for the same column with the opposite ordering
        direction."""
        flipped = _reverse_order_direction(self.uo)
        if flipped is None:
            raise ValueError  # pragma: no cover
        return OC(flipped)

    def __str__(self):
        return str(self.uo)

    def __repr__(self):
        return '<OC: {}>'.format(self)

def strip_labels(el):
    """Unwrap any labels around a
    :class:`sqlalchemy.sql.expression.ColumnElement`.

    :param el: The element to unwrap.
    :returns: The first element reached through ``.element`` that is not one
        of the labelled types (``el`` itself if it is not labelled).
    :raises ValueError: If a labelled element has no ``element`` attribute."""
    current = el
    while isinstance(current, _LABELLED):
        try:
            inner = current.element
        except AttributeError:
            raise ValueError  # pragma: no cover
        current = inner
    return current


def _get_order_direction(x):
    """
    Find the ordering direction (ASC or DESC) of a
    :class:`sqlalchemy.sql.expression.ColumnElement`, looking through nested
    ``.element`` wrappers.

    :param x: a :class:`sqlalchemy.sql.expression.ColumnElement`
    :return: `asc_op`, `desc_op` or `None` if no direction is found
    """
    node = x
    for _ in range(_WRAPPING_DEPTH):
        modifier = getattr(node, 'modifier', None)
        if modifier in (asc_op, desc_op):
            return modifier

        # No direction at this level: descend one wrapper, or stop if there
        # is nothing further to unwrap.
        node = getattr(node, 'element', None)
        if node is None:
            return None
    raise Exception(_WRAPPING_OVERFLOW)  # pragma: no cover


def _reverse_order_direction(ce):
    """
    Return a copy of a :class:`sqlalchemy.sql.expression.ColumnElement` whose
    ordering direction is flipped (ASC becomes DESC and vice versa). If no
    direction is found, the copy is returned as cloned.

    :param ce: a :class:`sqlalchemy.sql.expression.ColumnElement`
    """
    result = ce._clone()
    node = result
    for _ in range(_WRAPPING_DEPTH):
        modifier = getattr(node, 'modifier', None)
        if modifier in (asc_op, desc_op):
            node.modifier = desc_op if modifier == asc_op else asc_op
            return result
        if not hasattr(node, 'element'):
            # Reached the innermost element without finding a direction.
            return result
        # We are about to modify something inside node.element, so clone one
        # level deeper before descending.
        node._copy_internals()
        node = node.element
    raise Exception(_WRAPPING_OVERFLOW)  # pragma: no cover


def _remove_order_direction(ce):
    """
    Return a copy of a :class:`sqlalchemy.sql.expression.ColumnElement` with
    every ordering modifier (ASC/DESC, NULLS FIRST/LAST) taken out of its
    chain of ``.element`` wrappers.

    A warning is emitted for each NULLS FIRST / NULLS LAST modifier found.

    :param ce: a :class:`sqlalchemy.sql.expression.ColumnElement`
    """
    result = node = ce._clone()
    wrapper = None  # Closest enclosing non-modifier element, if any.
    for _ in range(_WRAPPING_DEPTH):
        modifier = getattr(node, 'modifier', None)
        if modifier in _UNSUPPORTED_ORDER_MODIFIERS:
            warn("One of your order columns had a NULLS FIRST or NULLS LAST "
                 "modifier; but sqlakeyset does not support order columns "
                 "with nulls. YOUR RESULTS WILL BE WRONG. See the "
                 "Limitations section of the sqlakeyset README for more "
                 "information.")
        if modifier in _ORDER_MODIFIERS:
            node._copy_internals()
            node = node.element
            if wrapper is None:
                # The modifier sat at the top level: its child becomes the
                # new top-level expression.
                result = node
            else:
                # Splice the modifier out of the wrapping chain by pointing
                # the enclosing wrapper straight at the modifier's child.
                wrapper.element = node
            continue

        if not hasattr(node, 'element'):
            return result
        wrapper = node
        # Something inside node.element might get changed later, so clone
        # one level deeper before descending.
        node._copy_internals()
        node = node.element
    raise Exception(_WRAPPING_OVERFLOW)  # pragma: no cover


class MappedOrderColumn:
    """An :class:`OC` placed in the context of a particular query/select.

    On top of the wrapped :class:`OC` this knows how to pull the ordering
    key's value out of a result row. Some queries need an extra entity added
    to them for that to be possible; when so, ``extra_entity`` is set."""

    def __init__(self, oc):
        self.oc = oc
        #: An extra SQLAlchemy ORM entity that this ordering column needs to
        #: add to its query in order to retrieve its value at each row. If no
        #: extra data is required, the value of this property will be ``None``.
        self.extra_entity = None

    def get_from_row(self, internal_row):
        """Extract the value of this ordering column from a result row.

        Not implemented here; this base implementation always raises
        :class:`NotImplementedError`."""
        raise NotImplementedError  # pragma: no cover

    @property
    def ob_clause(self):
        """The original ORDER BY (sub)clause underlying this column."""
        return self.oc.uo

    @property
    def reversed(self):
        """A shallow copy of this :class:`MappedOrderColumn` whose ``oc`` is
        replaced by its reversed counterpart."""
        flipped = copy(self)
        flipped.oc = self.oc.reversed
        return flipped

    def __str__(self):
        return str(self.oc)


class DirectColumn(MappedOrderColumn):
    """An ordering key that already appears as a column of the original
    query; its value is read from the result row by position."""

    def __init__(self, oc, index):
        super().__init__(oc)
        # Position used to index result rows.
        self.index = index

    def get_from_row(self, row):
        """Return the item of `row` at this column's ``index``."""
        position = self.index
        return row[position]

    def __repr__(self):
        return "Direct({}, {!r})".format(self.index, self.oc)


class AttributeColumn(MappedOrderColumn):
    """An ordering key that is available as an attribute of one of the
    original query's columns."""

    def __init__(self, oc, index, attr):
        super().__init__(oc)
        # Position of the owning object within a result row.
        self.index = index
        # Name of the attribute holding the ordering value on that object.
        self.attr = attr

    def get_from_row(self, row):
        """Return attribute ``attr`` of the item of `row` at ``index``."""
        owner = row[self.index]
        return getattr(owner, self.attr)

    def __repr__(self):
        return "Attribute({}.{}, {!r})".format(self.index, self.attr, self.oc)


class AppendedColumn(MappedOrderColumn):
    """An ordering key whose value is not present in the original query, so
    an additional labelled column has to be added to it."""

    # Class-wide counter used to generate names for appended columns.
    _counter = 0

    def __init__(self, oc, name=None):
        super().__init__(oc)
        if name:
            self.name = name
        else:
            # No (or empty) name given: generate one from the shared counter.
            AppendedColumn._counter += 1
            self.name = "_sqlakeyset_oc_{}".format(AppendedColumn._counter)
        # The extra column to add: the comparable value under our label.
        self.extra_entity = self.oc.comparable_value.label(self.name)

    def get_from_row(self, row):
        """Return the attribute of `row` named after this column's label."""
        return getattr(row, self.name)

    @property
    def ob_clause(self):
        """The labelled extra column, with ``.desc()`` applied when the
        wrapped :class:`OC` is not ascending."""
        labelled = self.extra_entity
        if self.oc.is_ascending:
            return labelled
        return labelled.desc()

    def __repr__(self):
        return "Appended({!r})".format(self.oc)

def derive_order_key(ocol, desc, index):
    """Try to work out how to read the value of `ocol` from one query column.

    :param ocol: The :class:`OC` to look up.
    :param desc: Either a column description as in
        :attr:`sqlalchemy.orm.query.Query.column_descriptions`, or a
        :class:`sqlalchemy.sql.expression.ColumnElement`.
    :param index: The index to store on the returned column; the returned
        object uses it to index result rows.

    :returns: Either a :class:`MappedOrderColumn` or `None`."""
    # Case 1: a bare column element -- it either is the ordering column or
    # it is not.
    if isinstance(desc, ColumnElement):
        if not desc.compare(ocol.comparable_value):
            return None
        return DirectColumn(ocol, index)

    entity, expr = desc['entity'], desc['expr']

    # Case 2: a Bundle -- look for a member column matching the ordering
    # column once labels are stripped.
    if isinstance(expr, Bundle):
        for attr_name, bundled in expr.columns.items():
            unlabelled = strip_labels(bundled)
            if unlabelled.compare(ocol.comparable_value):
                return AttributeColumn(ocol, index, attr_name)

    # Case 3: the description is a whole mapped entity ("a table").
    try:
        whole_entity = bool(entity == expr)
    except (sqlalchemy.exc.ArgumentError, TypeError):
        whole_entity = False

    if isinstance(expr, Mapper) and expr.class_ == entity:
        whole_entity = True

    if whole_entity:
        entity_mapper = class_mapper(desc['type'])
        try:
            prop = entity_mapper.get_property_by_column(ocol.element)
        except sqlalchemy.orm.exc.UnmappedColumnError:
            # The ordering column isn't mapped on this entity; keep looking.
            pass
        else:
            return AttributeColumn(ocol, index, prop.key)

    # Case 4: a mapped attribute -- match on table name and column name.
    if isinstance(expr, QueryableAttribute):
        table = expr.parent.local_table.description
        if ocol.table_name == table and ocol.name == expr.name:
            return DirectColumn(ocol, index)

    # Case 5: an attribute with a label -- compare rendered names.
    try:
        labelled_match = ocol.quoted_full_name == OC(expr).full_name
    except sqlalchemy.exc.ArgumentError:
        labelled_match = False
    if labelled_match:
        return DirectColumn(ocol, index)
    return None

def find_order_key(ocol, column_descriptions):
    """Work out how to populate the ordering column `ocol` from a query whose
    columns are described by `column_descriptions`.

    :param ocol: The :class:`OC` to look up.
    :param column_descriptions: The list of columns from which to attempt to
        derive the value of `ocol`.
    :returns: A :class:`MappedOrderColumn` wrapping `ocol`: the first
        non-``None`` result of :func:`derive_order_key`, or an
        :class:`AppendedColumn` if no column yields one."""
    # Lazy generator: descriptions after the first match are never examined.
    candidates = (
        derive_order_key(ocol, desc, position)
        for position, desc in enumerate(column_descriptions)
    )
    match = next((found for found in candidates if found is not None), None)
    if match is not None:
        return match

    # No existing column in the query lets us determine this ordering
    # column, so one has to be added.
    return AppendedColumn(ocol)