# -*- coding:utf-8 -*-
import logging
from collections import OrderedDict
import sqlalchemy.types as t
from sqlalchemy.inspection import inspect
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.relationships import RelationshipProperty
from sqlalchemy.sql.visitors import VisitableType
from sqlalchemy.orm.base import ONETOMANY, MANYTOONE, MANYTOMANY
from sqlalchemy.sql.type_api import TypeEngine

# Module-level logger.
logger = logging.getLogger(__name__)

# Shared "no extra options" value yielded as ``opts`` by the relation decision
# classes.  It is a single module-level dict, so callers must not mutate it.
EMPTY_DICT = {}


class InvalidStatus(Exception):
    """Raised when schema generation meets an inconsistent or unsupported state.

    Used for conflicting includes/excludes, unmapped column types and
    overrides that were never applied.
    """


"""
http://json-schema.org/latest/json-schema-core.html#anchor8
3.5.  JSON Schema primitive types

JSON Schema defines seven primitive types for JSON values:

    array
        A JSON array.
    boolean
        A JSON boolean.
    integer
        A JSON number without a fraction or exponent part.
    number
        Any JSON number. Number includes integer.
    null
        The JSON null value.
    object
        A JSON object.
    string
        A JSON string.
"""

# Tentative mapping from SQLAlchemy column type classes to JSON Schema
# primitive type names.  Lookups go through get_class_mapping(), so
# subclasses of these types can resolve through their base classes.
# Entries whose value is "xxx" have no settled JSON Schema equivalent yet
# ("xxx" is not a valid JSON Schema type).
default_column_to_schema = {
    t.String: "string",
    t.Text: "string",
    t.Integer: "integer",
    t.SmallInteger: "integer",
    t.BigInteger: "string",  # xxx
    # Numeric covers fixed-point decimals that may carry a fractional part;
    # "integer" would reject valid values such as 1.5.
    t.Numeric: "number",
    t.Float: "number",
    t.DateTime: "string",
    t.Date: "string",
    t.Time: "string",  # xxx
    t.LargeBinary: "xxx",
    t.Binary: "xxx",
    t.Boolean: "boolean",
    t.Unicode: "string",
    t.Concatenable: "xxx",
    t.UnicodeText: "string",
    t.Interval: "xxx",
    t.Enum: "string",
}


# Restriction callbacks: each takes (column, sub) and adds JSON Schema
# keywords to the property schema dict ``sub``.
def string_max_length(column, sub):
    """Copy the column type's ``length`` into ``sub`` as ``maxLength``.

    Nothing is written when the length is ``None``.
    """
    length = column.type.length
    if length is None:
        return
    sub["maxLength"] = length


def enum_one_of(column, sub):
    """Restrict ``sub`` to the values listed in ``column.type.enums``."""
    allowed = [value for value in column.type.enums]
    sub["enum"] = allowed


def _set_format(sub, fmt):
    # Shared helper: record a JSON Schema "format" keyword on ``sub``.
    sub["format"] = fmt


def datetime_format(column, sub):
    """Tag ``sub`` with format ``date-time`` (``column`` is not inspected)."""
    _set_format(sub, "date-time")


def date_format(column, sub):
    """Tag ``sub`` with format ``date`` (``column`` is not inspected)."""
    _set_format(sub, "date")


def time_format(column, sub):
    """Tag ``sub`` with format ``time`` (``column`` is not inspected)."""
    _set_format(sub, "time")


# Default restriction callbacks keyed by column type class.  Each callable is
# invoked as fn(column, sub) for columns whose type resolves to that class.
default_restriction_dict = dict(
    [
        (t.String, string_max_length),
        (t.Enum, enum_one_of),
        (t.DateTime, datetime_format),
        (t.Date, date_format),
        (t.Time, time_format),
    ]
)


class Classifier(object):
    """Resolve SQLAlchemy type instances to JSON Schema type names.

    ``mapping`` maps type classes to schema type names.  Lookups are
    delegated to ``get_class_mapping`` using this instance's ``see_mro`` and
    ``see_impl`` flags.
    """

    def __init__(self, mapping=default_column_to_schema, see_mro=True, see_impl=True):
        self.mapping = mapping
        self.see_mro = see_mro
        self.see_impl = see_impl

    def __getitem__(self, k):
        """Return ``(k.__class__, json_type)`` for the type instance ``k``.

        Raises:
            InvalidStatus: if no mapping entry matches ``k``'s class.
        """
        type_cls = k.__class__
        _matched, json_type = get_class_mapping(
            self.mapping, type_cls, see_mro=self.see_mro, see_impl=self.see_impl
        )
        if json_type is not None:
            return type_cls, json_type
        raise InvalidStatus(
            "notfound: {k}. (cls={cls})".format(k=k, cls=type_cls)
        )


def get_class_mapping(mapping, cls, see_mro=True, see_impl=True):
    """Look up ``cls`` in ``mapping``, falling back to bases and ``impl``.

    Resolution order:

    1. ``cls`` itself (a stored ``None`` value counts as missing);
    2. if ``see_mro``: the classes after ``cls`` in its MRO, stopping before
       ``TypeEngine``;
    3. if ``see_impl`` and ``cls`` has an ``impl`` attribute: the same
       procedure applied recursively to ``impl`` (or to its class when
       ``impl`` is not callable).

    Returns:
        ``(matched_class, value)``, or ``(None, None)`` when nothing matches.
    """
    found = mapping.get(cls)
    if found is not None:
        return cls, found

    if see_mro:
        for base in cls.mro()[1:]:
            if base is TypeEngine:
                break
            if base in mapping:
                return base, mapping[base]

    if not (see_impl and hasattr(cls, "impl")):
        return None, None

    target = cls.impl
    # Classes are always callable, so a non-callable impl is an instance;
    # look up its class instead.
    target = target if callable(target) else target.__class__
    return get_class_mapping(mapping, target, see_mro=see_mro, see_impl=see_impl)


# Module-wide default classifier (the default for SchemaFactory).  The
# historical name is misspelled; it is kept for backward compatibility and a
# correctly spelled alias is provided alongside it.
DefaultClassfier = Classifier(default_column_to_schema)
DefaultClassifier = DefaultClassfier

# Shared empty tuple, used as the getattr() fallback for properties that have
# no ``columns`` attribute.
Empty = ()


class BaseModelWalker(object):
    """Base class for walkers that enumerate the mapped properties of a model.

    ``includes`` / ``excludes`` are optional collections of property keys;
    naming the same key in both raises ``InvalidStatus``.  ``history`` holds
    the relationship properties currently being expanded on the way to this
    model (a fresh list is used when it is empty or omitted).
    """

    def __init__(self, model, includes=None, excludes=None, history=None):
        self.mapper = inspect(model).mapper
        self.includes = includes
        self.excludes = excludes
        self.history = history or []
        if not (includes and excludes):
            return
        overlap = set(includes).intersection(excludes)
        if overlap:
            raise InvalidStatus(
                "Conflict includes={}, exclude={}".format(includes, excludes)
            )

    def clone(self, name, mapper, includes, excludes, history):
        """Return a walker of the same class for ``mapper`` (``name`` is unused)."""
        walker_cls = self.__class__
        return walker_cls(mapper, includes, excludes, history)

    def from_child(self, model):
        """Return an unrestricted walker of the same class sharing ``history``."""
        walker_cls = self.__class__
        return walker_cls(model, history=self.history)


# mapper.column_attrs and mapper.attrs are not ordered, so the walkers below define their own `iterate` method.


class ForeignKeyWalker(BaseModelWalker):
    """Walk the column properties of the mapped table, foreign keys included."""

    def iterate(self):
        """Yield the property mapping each column of the local table, in table order."""
        # Resolve by column object rather than indexing mapper._props by the
        # column name: property keys may differ from column names
        # (e.g. ``name_ = Column("name", ...)``), which raised KeyError.
        for c in self.mapper.local_table.columns:
            yield self.mapper.get_property_by_column(c)

    def walk(self):
        """Yield iterated properties allowed by ``includes`` / ``excludes``."""
        for prop in self.iterate():
            if self.includes is not None and prop.key not in self.includes:
                continue
            if self.excludes is not None and prop.key in self.excludes:
                continue
            yield prop


class NoForeignKeyWalker(BaseModelWalker):
    """Walk the column properties of the mapped table, skipping foreign keys."""

    def iterate(self):
        """Yield the property mapping each column of the local table, in table order."""
        # Resolve by column object rather than indexing mapper._props by the
        # column name: property keys may differ from column names.
        for c in self.mapper.local_table.columns:
            yield self.mapper.get_property_by_column(c)

    def walk(self):
        """Yield allowed properties none of whose columns has a foreign key."""
        for prop in self.iterate():
            if self.includes is not None and prop.key not in self.includes:
                continue
            if self.excludes is not None and prop.key in self.excludes:
                continue
            if any(c.foreign_keys for c in getattr(prop, "columns", Empty)):
                continue
            yield prop


class StructuralWalker(BaseModelWalker):
    """Walk column properties and relationships, nesting related models.

    Foreign-key columns are skipped (the relationship represents them), as
    are relationships already present in ``history``.
    """

    def iterate(self):
        """Yield column properties in table order, then the mapper's relationships."""
        # Resolve by column object rather than indexing mapper._props by the
        # column name: property keys may differ from column names.
        for c in self.mapper.local_table.columns:
            yield self.mapper.get_property_by_column(c)
        for prop in self.mapper.relationships:
            yield prop

    def walk(self):
        """Yield allowed column/relationship properties not yet in ``history``."""
        for prop in self.iterate():
            if not isinstance(prop, (ColumnProperty, RelationshipProperty)):
                continue
            if self.includes is not None and prop.key not in self.includes:
                continue
            if self.excludes is not None and prop.key in self.excludes:
                continue
            if prop in self.history:
                continue
            if any(c.foreign_keys for c in getattr(prop, "columns", Empty)):
                continue
            yield prop


def get_children(name, params, splitter=".", default=None):  # todo: rename
    """Select the entries of ``params`` nested under ``name`` and strip the prefix.

    ``params`` may be a mapping (``{"group.name": v}`` -> ``{"name": v}``) or
    a list/tuple/set/frozenset of dotted keys (``["group.name"]`` ->
    ``["name"]``).  Exactly the leading ``name + splitter`` is removed, so
    deeper paths keep their remaining segments (``"group.users.name"`` ->
    ``"users.name"``) even if ``name`` itself contains ``splitter``.

    Returns:
        A dict for mapping input, a list for sequence/set input, otherwise
        ``default`` (e.g. for ``None``).
    """
    prefix = name + splitter
    cut = len(prefix)
    if hasattr(params, "items"):
        return {k[cut:]: v for k, v in params.items() if k.startswith(prefix)}
    if isinstance(params, (list, tuple, set, frozenset)):
        return [e[cut:] for e in params if e.startswith(prefix)]
    return default


# Sentinel override value: an override mapped to ``pop_marker`` removes that
# key from the target dict instead of assigning it.
pop_marker = object()


class CollectionForOverrides(object):
    """Wrap a user-supplied overrides mapping and track which keys were used.

    ``params`` may be ``None`` (treated as empty).  Every key starts out in
    ``not_used_keys`` and is removed from it once ``overrides`` applies it.
    """

    def __init__(self, params, pop_marker=pop_marker):
        self.params = params or {}
        # Read keys from the normalised dict: ``params`` itself may be None.
        self.not_used_keys = set(self.params.keys())
        self.pop_marker = pop_marker

    def __contains__(self, k):
        return k in self.params

    def overrides(self, basedict):
        """Apply every override entry to ``basedict`` in place.

        A value equal to ``pop_marker`` removes the key (``KeyError`` if it
        is absent); any other value is assigned.  Each applied key is marked
        as used; applying the overrides more than once is allowed.
        """
        for k, v in self.params.items():
            if v == self.pop_marker:
                basedict.pop(k)  # xxx: KeyError if k is absent
            else:
                basedict[k] = v
            # discard(): a key already marked as used must not raise.
            self.not_used_keys.discard(k)


class ChildFactory(object):
    """Build the walker, overrides and sub-schema for a relationship property.

    Dotted keys (``"<relationship key><splitter><child key>"``) in the
    parent's includes/excludes/overrides are routed to the child.  Unless
    ``bidirectional`` is true, the relationship's back-reference is excluded
    from the child.
    """

    def __init__(self, splitter=".", bidirectional=False):
        self.splitter = splitter
        self.bidirectional = bidirectional

    def default_excludes(self, prop):
        """Return the name(s) of ``prop``'s back-reference.

        ``backref`` may be a plain name or a ``(name, kwargs)`` tuple (the form
        produced by ``sqlalchemy.orm.backref()``); unset (``None``) entries
        are dropped.
        """
        names = []
        for ref in (prop.back_populates, prop.backref):
            if isinstance(ref, tuple):
                ref = ref[0] if ref else None
            if ref is not None:
                names.append(ref)
        return names

    def child_overrides(self, prop, overrides):
        """Return the overrides nested under ``prop.key``, same wrapper class."""
        name = prop.key
        children = get_children(name, overrides.params, splitter=self.splitter)
        return overrides.__class__(children, pop_marker=overrides.pop_marker)

    def child_walker(self, prop, walker, history=None):
        """Return a walker for ``prop``'s related mapper with nested includes/excludes."""
        name = prop.key
        # Child excludes must come from the parent's *excludes*: deriving them
        # from the includes made every nested include conflict with itself.
        # list() also copies, so extend() never mutates caller data.
        excludes = list(
            get_children(name, walker.excludes, splitter=self.splitter, default=[])
        )
        if not self.bidirectional:
            excludes.extend(self.default_excludes(prop))
        includes = get_children(name, walker.includes, splitter=self.splitter)

        return walker.clone(
            name, prop.mapper, includes=includes, excludes=excludes, history=history
        )

    def child_schema(
        self, prop, schema_factory, root_schema, walker, overrides, depth, history
    ):
        """Build the sub-schema for relationship ``prop`` one level deeper.

        Collection relationships (one-to-many, many-to-many) produce
        ``{"type": "array", "items": ...}``; others produce
        ``{"type": "object", "properties": ...}``.
        """
        subschema = schema_factory._build_properties(
            walker,
            root_schema,
            overrides,
            depth=(depth and depth - 1),  # None stays None (no limit)
            history=history,
            toplevel=False,
        )
        if prop.direction in (ONETOMANY, MANYTOMANY):
            return {"type": "array", "items": subschema}
        return {"type": "object", "properties": subschema}


# Action tags used by the relation-decision classes; _build_properties
# dispatches on RELATIONSHIP and FOREIGNKEY.
RELATIONSHIP, FOREIGNKEY, IMMEDIATE = "relationship", "foreignkey", "immediate"


class RelationDesicion(object):
    """Default relation decision: nest relationships, render columns directly.

    (The misspelled class and method names are kept for compatibility.)
    """

    def desicion(self, walker, prop, toplevel):
        """Yield ``(action, prop, opts)`` triples for ``prop``.

        Properties with a ``mapper`` attribute yield RELATIONSHIP; properties
        with ``columns`` yield FOREIGNKEY.  ``walker`` and ``toplevel`` are
        not used here.

        Raises:
            NotImplementedError: for any other kind of property.
        """
        if hasattr(prop, "mapper"):
            yield RELATIONSHIP, prop, EMPTY_DICT
        elif hasattr(prop, "columns"):
            yield FOREIGNKEY, prop, EMPTY_DICT
        else:
            # NotImplemented is a constant, not an exception class; calling
            # and raising it produced an unrelated TypeError.
            raise NotImplementedError(prop)


class UseForeignKeyIfPossibleDecision(object):
    """Relation decision that prefers foreign-key columns over nested objects.

    * many-to-one relationships yield their local foreign-key column
      properties, tagged with ``{"relation": <relationship key>}``; inside a
      nested walker, columns equal to the remote side of the relationship
      that led to it are skipped;
    * many-to-many relationships yield an inline array-of-strings schema;
    * other relationships are nested (RELATIONSHIP);
    * column properties yield FOREIGNKEY.
    """

    def desicion(self, walker, prop, toplevel):
        """Yield ``(action, prop, opts)`` triples for ``prop``.

        Raises:
            NotImplementedError: if ``prop`` has neither ``mapper`` nor
                ``columns``.
        """
        if hasattr(prop, "mapper"):
            if prop.direction == MANYTOONE:
                if not toplevel and walker.history:
                    # history[-1] is the relationship being expanded into this
                    # walker; history[0] is only the outermost one once nesting
                    # is deeper than one level.  Local columns equal to its
                    # remote side point back at the parent being described.
                    parent = walker.history[-1]
                    if prop.local_columns == parent.remote_side:
                        return
                for c in prop.local_columns:
                    # Resolve by column, not by name: property keys may
                    # differ from column names.
                    yield FOREIGNKEY, walker.mapper.get_property_by_column(c), {
                        "relation": prop.key
                    }
            elif prop.direction == MANYTOMANY:
                # Many-to-many relationships are not expanded; emit an inline
                # array-of-strings schema as the property value.
                yield {"type": "array", "items": {"type": "string"}}, prop, EMPTY_DICT
            else:
                yield RELATIONSHIP, prop, EMPTY_DICT
        elif hasattr(prop, "columns"):
            yield FOREIGNKEY, prop, EMPTY_DICT
        else:
            # NotImplemented is a constant, not an exception class.
            raise NotImplementedError(prop)


class SchemaFactory(object):
    """Build a JSON Schema dict for a SQLAlchemy mapped class.

    Args:
        walker: walker *class* (not instance), instantiated per model to
            enumerate its properties, e.g. ``StructuralWalker``.
        classifier: object whose ``[type_instance]`` lookup returns
            ``(type_class, json_type_name)``.
        restriction_dict: maps type classes to ``fn(column, sub)`` callables
            (or lists/tuples of them) that add keywords to a column schema.
        container_factory: factory for the ``properties`` containers.
        child_factory: builds walkers, overrides and sub-schemas for
            relationship properties.
        relation_decision: object whose ``desicion`` method decides how each
            walked property is rendered.
    """

    def __init__(
        self,
        walker,
        classifier=DefaultClassfier,
        restriction_dict=default_restriction_dict,
        container_factory=OrderedDict,
        child_factory=ChildFactory("."),
        relation_decision=RelationDesicion(),
    ):
        self.container_factory = container_factory
        self.classifier = classifier
        self.walker = walker  # class
        # One single-entry dict per restriction, so every restriction whose
        # type matches a column is applied, not only the closest match.
        self.restriction_set = [{k: v} for k, v in restriction_dict.items()]
        self.child_factory = child_factory
        self.relation_decision = relation_decision

    def __call__(self, model, includes=None, excludes=None, overrides=None, depth=None):
        """Return the JSON Schema dict for ``model``.

        Args:
            model: mapped class; its ``__name__`` becomes ``title`` and its
                docstring, if any, ``description``.
            includes: property keys to keep (dotted keys reach related models).
            excludes: property keys to drop (dotted keys reach related models).
            overrides: mapping wrapped in ``CollectionForOverrides``.
            depth: maximum relationship nesting depth; ``None`` means no limit.

        Raises:
            InvalidStatus: if some override key was never applied.
        """
        walker = self.walker(model, includes=includes, excludes=excludes)
        overrides = CollectionForOverrides(overrides or {})

        schema = {"title": model.__name__, "type": "object"}
        schema["properties"] = self._build_properties(
            walker, schema, overrides=overrides, depth=depth
        )

        if overrides.not_used_keys:
            raise InvalidStatus("invalid overrides: {}".format(overrides.not_used_keys))

        if model.__doc__:
            schema["description"] = model.__doc__

        required = self._detect_required(walker)

        if required:
            schema["required"] = required
        return schema

    def _add_restriction_if_found(self, D, column, itype):
        """Run every restriction registered for ``itype`` against ``column``, writing into ``D``."""
        for restriction_dict in self.restriction_set:
            _, fn = get_class_mapping(
                restriction_dict,
                itype,
                see_impl=self.classifier.see_impl,
                see_mro=self.classifier.see_mro,
            )
            if fn is not None:
                if isinstance(fn, (list, tuple)):
                    for f in fn:
                        f(column, D)
                else:
                    fn(column, D)

    def _add_property_with_reference(
        self, walker, root_schema, current_schema, prop, val
    ):
        """Store ``val`` in ``root_schema["definitions"]`` and reference it.

        ``current_schema[prop.key]`` becomes a ``$ref`` (wrapped in an array
        for ``"array"`` values).  The definition is keyed by the related class
        name.  It gets a ``required`` list only when that list is non-empty,
        as at the top level: JSON Schema draft 4 forbids an empty ``required``.
        """
        clsname = prop.mapper.class_.__name__
        definitions = root_schema.setdefault("definitions", {})
        ref = {"$ref": "#/definitions/{}".format(clsname)}

        if val["type"] == "object":
            current_schema[prop.key] = ref
        else:  # array
            current_schema[prop.key] = {"type": "array", "items": ref}
            val["type"] = "object"
            val["properties"] = val.pop("items")

        required = self._detect_required(walker.from_child(prop.mapper))
        if required:
            val["required"] = required
        definitions[clsname] = val

    def _build_properties(
        self, walker, root_schema, overrides, depth=None, history=None, toplevel=True
    ):
        """Return the ``properties`` container for the properties ``walker`` yields.

        ``depth`` limits relationship nesting (``<= 0`` yields an empty
        container).  ``history`` is the stack of relationships currently being
        expanded; it is shared with child walkers.

        Raises:
            NotImplementedError: for a column whose ``type`` is a type class
                rather than an instance.
        """
        if depth is not None and depth <= 0:
            return self.container_factory()

        D = self.container_factory()
        if history is None:
            history = []

        for prop in walker.walk():
            for action, target, opts in self.relation_decision.desicion(
                walker, prop, toplevel
            ):
                if action == RELATIONSHIP:  # RelationshipProperty
                    history.append(target)
                    try:
                        subwalker = self.child_factory.child_walker(
                            target, walker, history=history
                        )
                        suboverrides = self.child_factory.child_overrides(
                            target, overrides
                        )
                        value = self.child_factory.child_schema(
                            target,
                            self,
                            root_schema,
                            subwalker,
                            suboverrides,
                            depth=depth,
                            history=history,
                        )
                        # Still inside the pushed history: from_child() relies on it.
                        self._add_property_with_reference(
                            walker, root_schema, D, target, value
                        )
                    finally:
                        history.pop()
                elif action == FOREIGNKEY:  # ColumnProperty
                    # Reset per property so a column-less property never
                    # reuses the previous property's schema.
                    sub = None
                    for c in target.columns:
                        if type(c.type) == VisitableType:
                            # presumably c.type is a bare type class rather
                            # than an instance -- not supported.
                            raise NotImplementedError(c)
                        sub = {}
                        itype, sub["type"] = self.classifier[c.type]

                        self._add_restriction_if_found(sub, c, itype)

                        if c.doc:
                            sub["description"] = c.doc

                        if c.name in overrides:
                            # NOTE(review): CollectionForOverrides.overrides
                            # applies *every* entry to ``sub``, not only the one
                            # for ``c.name`` -- confirm intended semantics.
                            overrides.overrides(sub)
                        if opts:
                            sub.update(opts)
                        D[c.name] = sub
                    if sub is not None:
                        D[target.key] = sub
                else:  # immediate: the action itself is the property schema
                    D[target.key] = action
        return D

    def _detect_required(self, walker):
        """Return keys of walked properties having a NOT NULL column with no default."""
        return [
            prop.key
            for prop in walker.walk()
            if any(
                not c.nullable and c.default is None
                for c in getattr(prop, "columns", Empty)
            )
        ]