#!/usr/bin/env python
# -*- coding: utf-8 -*-
# **************************************************************************
# Copyright © 2017-2020 jianglin
# File Name: backends.py
# Author: jianglin
# Email: mail@honmaple.com
# Created: 2017-04-15 20:03:27 (CST)
# Last Update: Monday 2020-05-18 23:06:43 (CST)
#          By:
# Description:
# **************************************************************************
import logging

from flask_sqlalchemy import models_committed
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect
from werkzeug.utils import import_string
from flask.helpers import locked_cached_property

from .signal import default_signal


def relation_column(instance, fields):
    '''
    Resolve a dotted relation field on an instance,
    such as: user.username
    such as: replies.content

    For a lazy='dynamic' relationship only the first related row is used;
    returns '' when the relation is empty.
    '''
    relation = getattr(instance.__class__, fields[0]).property
    _field = getattr(instance, fields[0])
    if relation.lazy == 'dynamic':
        _field = _field.first()
    return getattr(_field, fields[1]) if _field else ''
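# Usage sketch (hypothetical models): given a Post instance whose "user"
# relationship points to a User with username "alice",
#     relation_column(post, "user.username".split("."))  # -> "alice"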


class BaseSchema(object):
    def __init__(self, index):
        self.index = index

    def _fields(self):
        return dict()

    @property
    def fields(self):
        model = self.index.model
        schema_fields = self._fields()
        primary_keys = [key.name for key in inspect(model).primary_key]

        schema = getattr(model, "__msearch_schema__", dict())
        for field in self.index.searchable:
            if '.' in field:
                fields = field.split('.')
                field_attr = getattr(
                    getattr(model, fields[0]).property.mapper.class_,
                    fields[1])
            else:
                field_attr = getattr(model, field)

            if field in schema:
                field_type = schema[field]
                if isinstance(field_type, str):
                    schema_fields[field] = self.fields_map(field_type)
                else:
                    schema_fields[field] = field_type
                continue

            if hasattr(field_attr, 'descriptor') and isinstance(
                    field_attr.descriptor, hybrid_property):
                schema_fields[field] = self.fields_map("text")
                continue

            if field in primary_keys:
                schema_fields[field] = self.fields_map("primary")
                continue

            field_type = field_attr.property.columns[0].type
            schema_fields[field] = self.fields_map(field_type)
        return schema_fields
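
    # A minimal sketch of a concrete schema, assuming a whoosh-like backend;
    # the WhooshSchema name and field classes are illustrative only.  A
    # subclass supplies fields_map() to turn column types into index fields:
    #
    #     class WhooshSchema(BaseSchema):
    #         def fields_map(self, field_type):
    #             if field_type == "primary":
    #                 return ID(stored=True, unique=True)
    #             return TEXT()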


class BaseBackend(object):
    def __init__(self, app=None, db=None, analyzer=None):
        """
        You can custom analyzer by::

            from jieba.analyse import ChineseAnalyzer
            search = Search(analyzer = ChineseAnalyzer)
        """
        self._signal = None
        self._indexs = dict()
        self.db = db
        self.analyzer = analyzer
        if app is not None:
            self.init_app(app)

    def _setdefault(self, app):
        app.config.setdefault("MSEARCH_PRIMARY_KEY", "id")
        app.config.setdefault("MSEARCH_INDEX_NAME", "msearch")
        app.config.setdefault("MSEARCH_INDEX_SIGNAL", default_signal)
        app.config.setdefault("MSEARCH_ANALYZER", None)
        app.config.setdefault("MSEARCH_ENABLE", True)
        app.config.setdefault("MSEARCH_LOGGER", logging.WARNING)

    def _signal_connect(self, app):
        if app.config["MSEARCH_ENABLE"]:
            signal = app.config["MSEARCH_INDEX_SIGNAL"]
            if isinstance(signal, str):
                self._signal = import_string(signal)
            else:
                self._signal = signal
            models_committed.connect(self.index_signal)

    @locked_cached_property
    def logger(self):
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(self.app.config["MSEARCH_LOGGER"])
        return logger

    def index_signal(self, sender, changes):
        return self._signal(self, sender, changes)
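
    # MSEARCH_INDEX_SIGNAL may also be a custom callable (or an import
    # string) with the same signature as default_signal; flask_sqlalchemy
    # delivers ``changes`` as (instance, operation) tuples.  Sketch, with
    # an illustrative function name:
    #
    #     def my_signal(backend, sender, changes):
    #         for instance, operation in changes:
    #             ...  # dispatch to backend.create_one_index & co.
    #
    #     app.config["MSEARCH_INDEX_SIGNAL"] = my_signal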

    def init_app(self, app):
        self.app = app
        if not self.db:
            self.db = self.app.extensions['sqlalchemy'].db
        self.db.Model.query_class = self._query_class(
            self.db.Model.query_class)

    def _query_class(self, q):
        _self = self

        class Query(q):
            def msearch(self, query, fields=None, limit=None, or_=False):
                model = self._mapper_zero().class_
                return _self.msearch(model, query, fields, limit, or_)

        return Query
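
    # With the patched query class from init_app(), any model that defines
    # __searchable__ can be searched directly through its query, e.g.
    # (Post is a hypothetical model):
    #
    #     Post.query.msearch("keyword", fields=["title"], limit=20).all()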

    def create_index(self,
                     model='__all__',
                     update=False,
                     delete=False,
                     yield_per=100):
        if model == '__all__':
            return self.create_all_index(update, delete, yield_per)
        ix = self.index(model)
        instances = model.query.enable_eagerloads(False).yield_per(yield_per)
        for instance in instances:
            self.create_one_index(instance, update, delete, False)
        ix.commit()
        return ix

    def create_all_index(self, update=False, delete=False, yield_per=100):
        all_models = self.db.Model._decl_class_registry.values()
        models = [i for i in all_models if hasattr(i, '__searchable__')]
        ixs = []
        for m in models:
            ix = self.create_index(m, update, delete, yield_per)
            ixs.append(ix)
        return ixs

    def update_one_index(self, instance, commit=True):
        return self.create_one_index(instance, update=True, commit=commit)

    def delete_one_index(self, instance, commit=True):
        return self.create_one_index(instance, delete=True, commit=commit)

    def update_all_index(self, yield_per=100):
        return self.create_all_index(update=True, yield_per=yield_per)

    def delete_all_index(self, yield_per=100):
        return self.create_all_index(delete=True, yield_per=yield_per)

    def update_index(self, model='__all__', yield_per=100):
        return self.create_index(model, update=True, yield_per=yield_per)

    def delete_index(self, model='__all__', yield_per=100):
        return self.create_index(model, delete=True, yield_per=yield_per)
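
    # Typical maintenance calls on the extension instance (Post is a
    # hypothetical model):
    #
    #     search.create_index()        # build indexes for all models
    #     search.update_index(Post)    # re-index one model
    #     search.delete_index(Post)    # clear one model's index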

    def whoosh_search(self, m, query, fields=None, limit=None, or_=False):
        self.logger.warning(
            'whoosh_search is deprecated, please use msearch instead')
        return self.msearch(m, query, fields, limit, or_)

    # def msearch(self, m, query, fields=None, limit=None, or_=False):
    #     raise NotImplementedError