# -*- coding: utf-8 -*-
""" Functions for database operations.

    cls : class
        The ORM mapped class.
Note:
    1) SQLAlchemy exceptions are not handled in the following functions (
    except IntegrityError when inserting), so when using these functions,
    you should handle exceptions by yourself.
    2) To avoid concurrency conflict, especially when inserting URLs (we
    have several applications that insert URL into database simultaneously,
    including twitter streaming and scrapy crawling), we commit data into
    database as soon as possible.
    3) We do not delete objects in the database, so please set
    expire_on_commit=False whenusing session. Otherwise, sqlalchemy will
    re-fetch the orm object if you access it again after commit.
"""
#
# written by Chengcheng Shao <sccotte@gmail.com>

from hoaxy.database import ENGINE
from hoaxy.database.models import AssUrlPlatform, Platform
from hoaxy.database.models import Url, Site, AlternateDomain, SiteTag
from hoaxy.database.models import TwitterUserUnion
from sqlalchemy import and_, func, or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload, load_only
import logging
import sqlalchemy
import sqlparse

logger = logging.getLogger(__name__)


def get_m(session,
          cls,
          fb_kw=None,
          f_expr=None,
          ob_expr=None,
          limit=None,
          options=None):
    """Query ORMs.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    cls : class
        The ORM model class.
    fb_kw : dict
        The filter_by keywords, used by query.filter_by().
    f_expr : list
        The filter expression, used by query.filter().
    ob_expr : object
        The order by expression, used by query.order_by().
    limit : int
        The limit expression, used by query.limit().
    options : list
        Other query options, e.g, load, join options, used by query.options().

    Returns
    -------
    list
        A list of queried ORM objects.
    """
    q = session.query(Site)
    if fb_kw:
        q = q.filter_by(**fb_kw)
    if f_expr:
        q = q.filter(*f_expr)
    if ob_expr is not None:
        q = q.order_by(ob_expr)
    if options:
        q = q.options(*options)
    if limit:
        q = q.limit(limit)
    return q.all()


def get_max(session, col, fb_kw=None, f_expr=None):
    """Get the maximum value of one specified column.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    col : object
        A column object.
    fb_kw : dict
        The filter_by keywords, used by query.filter_by().
    f_expr : list
        The filter expressions, used by query.filter().

    Returns
    -------
    int
        The maximum value of `col`, if no records match the query, `None` is
        returned.
    """
    q = session.query(func.max(col))
    if fb_kw:
        q = q.filter_by(**fb_kw)
    if f_expr:
        q = q.filter(*f_expr)
    return q.scalar()


def get_msites(session,
               fb_kw=dict(is_enabled=True),
               f_expr=None,
               ob_expr=None,
               limit=None,
               options=[
                   joinedload(Site.alternate_domains),
                   joinedload(Site.site_tags)
               ]):
    """A quick query to fetch sites with frequently used attributes. Returned
    as ORM object.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    fb_kw : dict
        The filter_by keywords, user by query.filter_by().
    f_expr : list
        A list of filter expressions, used by query.filter().
    ob_expr : object
        Order by expression, used by query.order_by().
    limit : int
        Limit expression, used by query.limit().
    options : list
        query options, used by query.options().

    Returns
    -------
    list
        A list of Site objects.
    """
    return get_m(session, Site, fb_kw, f_expr, ob_expr, limit, options)


def get_site_tuples(session,
                    fb_kw=dict(is_enabled=True),
                    f_expr=None,
                    ob_expr=None,
                    limit=None,
                    options=[
                        joinedload(Site.alternate_domains),
                        joinedload(Site.site_tags)
                    ]):
    """A quick query to fetch sites with frequently used attributes. Returned
    as tuple.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    fb_kw : dict
        The filter_by keywords, user by query.filter_by().
    f_expr : list
        A list of filter expressions, used by query.filter().
    ob_expr : object
        Order by expression, used by query.order_by().
    limit : int
        Limit expression, used by query.limit().
    options : list
        query options, used by query.options().

    Returns
    -------
    list
        A list of tuple (id, domain).
    """
    r = []
    for ms in get_msites(
            session,
            fb_kw=fb_kw,
            f_expr=f_expr,
            ob_expr=ob_expr,
            limit=limit,
            options=options):
        r.append((ms.id, ms.domain))
        for d in ms.alternate_domains:
            r.append((ms.id, d.name))
    return r


def create_m(session, cls, data):
    """ Insert an orm object into database.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    cls : class
        The ORM mapped class.
    data : dict
        A dict that contains necessary attributes of the ORM objects.

    Notes
    -----
    This is a fast method to insert a record into one table, when
    duplicate happens, the record will be ignored.

    No ORM object is returned!
    """
    session.add(cls(**data))
    try:
        session.commit()
    # Already in db
    except IntegrityError as e:
        logger.debug(e)
        session.rollback()


def create_or_get_m(session, cls, data, fb_uk):
    """Try to insert an record into table, if exist, return it.

    This function first try to insert the record, if fail because of
    duplications. then query it.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    cls : class
        The ORM mapped class.
    data : dict
        A dict that contains necessary attributes of the ORM objects.
    fb_uk : string or list
        The unique columns to identify records in the table.

    Returns
    -------
    object
        The created model object.

    """
    m = cls(**data)
    session.add(m)
    try:
        session.commit()
        return m
    except IntegrityError as e:
        logger.debug(e)
        session.rollback()
        q = session.query(cls)
        # set up filter_by()
        fb = dict()
        if isinstance(fb_uk, str):
            fb[fb_uk] = data[fb_uk]
        elif isinstance(fb_uk, (list, tuple)):
            for k in fb_uk:
                fb[k] = data[k]
        q = q.filter_by(**fb)
        return q.one()


def get_or_create_m(session,
                    cls,
                    data,
                    fb_uk=None,
                    fb_kws=None,
                    f_expr=None,
                    onduplicate='ignore',
                    load_cols=None):
    """Try to get one record from table, if not exist, try to insert it.

    This function first try to insert the record, if fail because of
    duplications. then query it.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    cls : class
        The ORM mapped class.
    data : dict
        A dict that contains necessary attributes of the ORM objects.
    fb_uk : string or list
        The unique columns to identify records in the table.
    fb_kw : dict
        The filter_by keywords, user by query.filter_by().
    f_expr : list
        A list of filter expressions, used by query.filter().
    onduplicate : {'ignore', 'update'}
        Handle when duplication happens. If 'ignore', then ignore it. If
        'update', then update the record according the `data`.
    load_cols : list
        A list of columns objects that will be loaded. This is very useful if
        you would like to load only frequently used columns.

    Returns
    -------
    object
        The created or existed model object.

    """
    q = session.query(cls)
    # set up filter_by()
    if fb_uk:
        fb = dict()
        if isinstance(fb_uk, str):
            fb[fb_uk] = data[fb_uk]
        elif isinstance(fb_uk, (list, tuple)):
            for k in fb_uk:
                fb[k] = data[k]
        q = q.filter_by(**fb)
    if fb_kws:
        q = q.filter_by(**fb_kws)
    # set up filter()
    if f_expr:
        q = q.filter(*f_expr)
    if load_cols:
        q = q.options(load_only(*load_cols))
    mobj = q.one_or_none()
    if mobj:
        if onduplicate == 'update':
            q.update(data)
            session.commit()
    else:
        mobj = cls(**data)
        session.add(mobj)
        try:
            session.commit()
        except IntegrityError as e:
            logger.warning('Concurrency error %s!', e)
            session.rollback()
            mobj = q.one()
    return mobj


def append_platform_to_url(session, url_id, platform_id):
    """Set platform_id for a URL record.

    The relationship between table url and platform is M:N, so there is a
    association table called ass_url_platform to connect them.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    url_id : int
        The id of a URL record.
    platform_id : int
        The id of a platform record.
    """
    if session.query(AssUrlPlatform.id).filter_by(
            url_id=url_id, platform_id=platform_id).scalar() is None:
        session.add(AssUrlPlatform(url_id=url_id, platform_id=platform_id))
        try:
            session.commit()
        except IntegrityError as e:
            logger.warning('Error Concurrecy conflict %s', e)
            session.rollback()


def get_or_create_murl(session,
                       data,
                       platform_id=None,
                       load_cols=['id', 'date_published']):
    """Get a URL record from table, if not exists, insert it.

    The function is similar as `get_or_create_m`. The difference is how to
    handle duplications. In this function, try to update 'date_published'
    if `data['date_published']` is not None.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    data : dict
        A dict that contains necessary attributes of the ORM objects.
    platform_id : int
        The id of a platform object.
    load_cols : list
        The columns to be loaded. Default is ['id', 'date_published'].

    Returns
    -------
    object
        A URL model object.
    """
    q = session.query(Url).filter_by(raw=data['raw'])\
        .options(load_only(*load_cols))
    murl = q.one_or_none()
    if murl:
        # update date_published if possible
        if murl.date_published is None and \
                data.get('date_published', None) is not None:
            murl.date_published = data['date_published']
            session.commit()
    else:
        murl = Url(**data)
        session.add(murl)
        try:
            session.commit()
        except IntegrityError as e:
            logger.warning('Concurrecy conflict %s', e)
            session.rollback()
            murl = q.one()
    if platform_id is not None:
        append_platform_to_url(session, murl.id, platform_id)
    return murl


def qquery_msite(session, name=None, domain=None):
    """A quick way to query site by its name or domain.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    name : string
        The name of the site to query.
    domain : string
        The primary domain of the site to query.

    Returns
    -------
    object
        The Site ORM object.
    """
    if name is None and domain is None:
        raise TypeError('name or domain are required!')
    q = session.query(Site)
    if name is not None:
        q = q.filter_by(name=name)
    else:
        q = q.filter_by(domain=domain)
    q = q.options(
        joinedload(Site.alternate_domains), joinedload(Site.site_tags))
    return q.one_or_none()


def get_or_create_msite(session,
                        site,
                        alternate_domains=[],
                        site_tags=[],
                        onduplicate='update'):
    """Get a site from table, if not exist, insert into table.

    This function take cares of the alternate_domains relationi (1:M) and
    site_tags relation (M:N).

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    site : dict
        The site data.
    alternate_domains : list
        A list of alternate domain dict with keys 'name' and 'is_alive'.
    site_tags : list
        A list of site tag dict with keys 'name' and 'source'
    onduplicate : {'ignore', 'update'}
        How to handle duplication.
    """
    # FIRST, check whether this site exist
    q = session.query(Site).filter(
        or_(Site.domain.like(site['domain']), Site.name.like(site[
            'name']))).options(
                joinedload(Site.alternate_domains), joinedload(Site.site_tags))
    msite = q.one_or_none()
    if msite is None:
        msite = Site(**site)
        for d in alternate_domains:
            mad = get_or_create_m(session, AlternateDomain, d, fb_uk='name')
            msite.alternate_domains.append(mad)
        for t in site_tags:
            mtag = get_or_create_m(
                session, SiteTag, t, fb_uk=['name', 'source'])
            msite.site_tags.append(mtag)
        session.add(msite)
    elif onduplicate == 'update':
        # UPDATE site
        session.query(Site).filter_by(id=msite.id).update(site)
        adding_domains = [d['name'] for d in alternate_domains]
        owned_domains = [mad.name for mad in msite.alternate_domains]
        # HANDLE ALTERNATE DOMAINS
        # delete non-need ones
        for mad in msite.alternate_domains:
            if mad.name not in adding_domains:
                session.delete(mad)
        # add new ones
        for d in alternate_domains:
            if d['name'] not in owned_domains:
                session.add(AlternateDomain(site_id=msite.id, **d))
        # HANDLE SITE TAGS
        adding_tags = [(t['name'], t['source']) for t in site_tags]
        owned_tags = [(mt.name, mt.source) for mt in msite.site_tags]
        # delete non-need ones
        for mt in msite.site_tags:
            if (mt.name, mt.source) not in adding_tags:
                session.delete(mt)
        # add new ones
        for t in site_tags:
            if (t['name'], t['source']) not in owned_tags:
                mtag = get_or_create_m(
                    session, SiteTag, t, fb_uk=['name', 'source'])
                msite.site_tags.append(mtag)
    try:
        session.commit()
    except IntegrityError as e:
        logger.exception(e)
        session.rollback()


def create_or_update_muser(session, data):
    muser = TwitterUserUnion(**data)
    try:
        session.add(muser)
        session.commit()
    except IntegrityError as e:
        logger.debug('Error %s', e)
        session.rollback()
        raw_id = data.pop('raw_id')
        session.query(TwitterUserUnion).filter_by(raw_id=raw_id).update(data)
        session.commit()
        data['raw_id'] = raw_id


def convert_to_sqlalchemy_statement(raw_sql_script):
    """Convert raw SQL into SQLAlchemy statement."""
    # remove comment and tail spaces
    formated_sql_script = sqlparse.format(
        raw_sql_script.strip(), strip_comments=True)
    return sqlparse.split(formated_sql_script)


def migrate_db(sql_script):
    """Migrate database using sql_script."""
    logger.info('Start migration %r', __file__)
    with ENGINE.connect() as conn:
        conn.autocommit = True
        conn.execution_options(isolation_level='AUTOCOMMIT')
        for stmt in convert_to_sqlalchemy_statement(sql_script):
            logger.debug('Executing sql statement %r', stmt)
            conn.execute(stmt)
    logger.info('Migration finished!')


def column_windows(session, w_column, w_size, fb_kw=None, f_expr=None):
    """Return a series of WHERE clauses against a given column that break it
    into windows.

    Parameters
    ----------
    session : object
        An instance of SQLAlchemy Session.
    w_column : object
        Column object that is used to split into windows, should be an
        integer column.
    w_size : int
        Size of the window
    fb_kw : dict
        The filter_by keywords, used by query.filter_by().
    f_expr : list
        The filter expressions, used by query.filter().

    Returns
    -------
    iterable
        Each element of the iterable is a whereclause expression, which
        specify the range of the window over the column `w_col`.

    Exmaple
    -------
    for whereclause in column_windows(q.session, w_column, w_size):
        for row in q.filter(whereclause).order_by(w_column):
            yield row

    """

    def int_for_range(start_id, end_id):
        """Internal function to build range."""
        if end_id:
            return and_(w_column >= start_id, w_column < end_id)
        else:
            return w_column >= start_id

    q = session.query(
        w_column, func.row_number().over(order_by=w_column).label('w_row_num'))
    if fb_kw:
        q = q.filter_by(**fb_kw)
    if f_expr:
        q = q.filter(*f_expr)
    q = q.from_self(w_column)
    if w_size > 1:
        q = q.filter(sqlalchemy.text("w_row_num % {}=1".format(w_size)))

    intervals = [id for id, in q]

    while intervals:
        start = intervals.pop(0)
        if intervals:
            end = intervals[0]
        else:
            end = None
        yield int_for_range(start, end)


def get_platform_id(session, name):
    """A quick query to get the platform id by its name."""
    return session.query(Platform.id).filter_by(name=name).scalar()