#
# Copyright 2018 Red Hat, Inc.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU Affero General Public License as
#    published by the Free Software Foundation, either version 3 of the
#    License, or (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU Affero General Public License for more details.
#
#    You should have received a copy of the GNU Affero General Public License
#    along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Database accessor for report data."""
import logging
import uuid
from decimal import Decimal
from decimal import InvalidOperation

import django.apps
from django.db import connection
from tenant_schemas.utils import schema_context

from masu.config import Config
from masu.database.koku_database_access import KokuDBAccess
from reporting_common import REPORT_COLUMN_MAP

LOG = logging.getLogger(__name__)


class ReportSchema:
    """A container for the reporting table objects.

    Each registered Django model is exposed as an attribute named after its
    ``db_table``. ``column_types`` maps each table name to a
    ``{field_name: Django internal field type}`` dict covering the field names
    found in ``REPORT_COLUMN_MAP[table_name].values()``.
    """

    def __init__(self, tables):
        """Initialize the report schema.

        Args:
            tables (iterable): Django model classes to register
        """
        self.column_types = {}
        self._set_reporting_tables(tables)

    def _set_reporting_tables(self, models):
        """Register model classes and record their column types.

        Models whose ``db_table`` contains the substring "django" are skipped.

        Args:
            models (iterable): Django model classes to register
        """
        column_types = {}
        for model in models:
            table_name = model._meta.db_table
            if "django" in table_name:
                continue
            setattr(self, table_name, model)
            columns = REPORT_COLUMN_MAP[table_name].values()
            column_types[table_name] = {
                column: model._meta.get_field(column).get_internal_type() for column in columns
            }
        # Publish the finished mapping once rather than on every iteration.
        self.column_types = column_types


class ReportDBAccessorBase(KokuDBAccess):
    """Class to interact with customer reporting tables."""

    def __init__(self, schema):
        """Initialize the accessor for one customer schema.

        The parent initializer receives the schema first; afterwards every
        model known to Django is loaded into a ``ReportSchema``.

        Args:
            schema (str): The customer schema to associate with
        """
        super().__init__(schema)
        registered_models = django.apps.apps.get_models()
        self.report_schema = ReportSchema(registered_models)

    @property
    def decimal_precision(self):
        """Return the quantize exponent string for decimal values.

        The string has the form ``"0E-<n>"``, where ``n`` is
        ``Config.REPORTING_DECIMAL_PRECISION``.
        """
        return "0E-{}".format(Config.REPORTING_DECIMAL_PRECISION)

    def create_temp_table(self, table_name, drop_column=None):
        """Create a temporary table shaped like ``table_name`` and return its name.

        The temporary table is named ``<table_name>_<uuid4>``, with the
        uuid's dashes replaced by underscores, and is created with
        ``CREATE TEMPORARY TABLE ... (LIKE table_name)``.

        Args:
            table_name (str): The table whose structure is copied
            drop_column (str): Optional column to drop from the copy

        Returns:
            (str): The name of the temporary table
        """
        suffix = str(uuid.uuid4()).replace("-", "_")
        temp_name = "_".join((table_name, suffix))
        create_sql = f"CREATE TEMPORARY TABLE {temp_name} (LIKE {table_name})"
        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.execute(create_sql)
            if drop_column:
                cursor.execute(f"ALTER TABLE {temp_name} DROP COLUMN {drop_column}")
        return temp_name

    def create_new_temp_table(self, table_name, columns):
        """Create a temporary table from explicit column definitions.

        Args:
            table_name (str): Prefix for the temporary table name; a
                uuid4-based suffix (dashes replaced by underscores) is appended
            columns (list): A list of dicts mapping column name to the SQL
                column type text, e.g. ``[{"id": "integer"}, {"name": "text"}]``

        Returns:
            (str): The name of the temporary table
        """
        temp_table_name = table_name + "_" + str(uuid.uuid4()).replace("-", "_")
        # Flatten the {name: type} dicts into "name type, name type, ...".
        column_defs = ", ".join(
            f"{name} {column_type}" for column in columns for name, column_type in column.items()
        )
        table_sql = f"CREATE TEMPORARY TABLE {temp_table_name} ({column_defs})"
        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.execute(table_sql)

        return temp_table_name

    def merge_temp_table(self, table_name, temp_table_name, columns, conflict_columns):
        """Upsert all rows of a temp table into a main table, then empty the temp table.

        Rows are copied with ``INSERT ... SELECT``. When a row collides on
        ``conflict_columns``, every column in ``columns`` is overwritten with
        the incoming value. Afterwards all rows are deleted from the temp table.

        Args:
            table_name (str): The main table to insert into
            temp_table_name (str): The temp table to pull from
            columns (list): Columns to copy and to overwrite on conflict
            conflict_columns (list): Columns forming the ON CONFLICT target

        Returns:
            (None)

        """
        column_list = ",".join(columns)
        conflict_list = ",".join(conflict_columns)
        assignments = ",".join(f"{col} = excluded.{col}" for col in columns)
        upsert_sql = f"""
            INSERT INTO {table_name} ({column_list})
                SELECT {column_list}
                FROM {temp_table_name}
                ON CONFLICT ({conflict_list}) DO UPDATE
                SET {assignments}
            """
        delete_sql = f"DELETE FROM {temp_table_name}"
        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.execute(upsert_sql)
            cursor.execute(delete_sql)

    def bulk_insert_rows(self, file_obj, table, columns, sep=","):
        """Load CSV rows into a table using ``COPY ... FROM STDIN``.

        Args:
            file_obj (file): A file-like object containing CSV rows
            table (str): The name of the database table to copy into
            columns (list): Column names, in the same order as the CSV fields
            sep (str): The field delimiter used in the file. Default: ','

        """
        column_list = ", ".join(columns)
        copy_sql = f"COPY {table} ({column_list}) FROM STDIN WITH CSV DELIMITER '{sep}'"
        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.copy_expert(copy_sql, file_obj)

    def _get_db_obj_query(self, table, columns=None):
        """Build a queryset over one reporting table.

        Args:
            table (DjangoModel or str): The model to query, or the name of a
                table registered as an attribute on ``self.report_schema``
            columns (list): If given, only these columns are selected via
                ``values()``

        Returns:
            (QuerySet): ``objects.values(*columns)`` when columns are given,
                otherwise ``objects.all()``

        """
        model = getattr(self.report_schema, table) if isinstance(table, str) else table

        # NOTE(review): the queryset is built inside schema_context but is
        # evaluated lazily by the caller; _get_primary_key wraps its evaluation
        # in schema_context itself.
        with schema_context(self.schema):
            if not columns:
                return model.objects.all()
            return model.objects.values(*columns)

    def create_db_object(self, table_name, data):
        """Create, save, and return a new row in a reporting table.

        Args:
            table_name (str): The name of a table registered on ``self.report_schema``
            data (dict): Field values for the new row; passed through
                ``clean_data``, which modifies the dict in place

        Returns:
            (Model): The saved Django model instance

        """
        model = getattr(self.report_schema, table_name)
        cleaned = self.clean_data(data, table_name)

        with schema_context(self.schema):
            instance = model(**cleaned)
            instance.save()
        return instance

    def insert_on_conflict_do_nothing(self, table, data, conflict_columns=None):
        """Write an INSERT statement with an ON CONFLICT DO NOTHING clause.

        This is useful to avoid duplicate row inserts. Intended for
        single row inserts.

        Args:
            table (DjangoModel): The model class whose table is inserted into
            data (dict): Column values to insert; passed through ``clean_data``,
                which modifies the dict in place
            conflict_columns (list): Columns for the ON CONFLICT target. When
                omitted, any conflict is ignored.

        Returns:
            The ``id`` of the matching row, looked up by ``conflict_columns``
            when given, otherwise by all inserted values

        """
        # Read _meta from the class; instantiating the model is unnecessary
        # work and fires Django's pre_init/post_init signals.
        table_name = table._meta.db_table
        data = self.clean_data(data, table_name)
        columns_formatted = ", ".join(str(column) for column in data)
        values = list(data.values())
        val_str = ",".join(["%s"] * len(data))
        insert_sql = f"""
            INSERT INTO {self.schema}.{table_name}({columns_formatted}) VALUES({val_str})
            """
        if conflict_columns:
            conflict_columns_formatted = ", ".join(conflict_columns)
            insert_sql += f" ON CONFLICT ({conflict_columns_formatted}) DO NOTHING;"
        else:
            insert_sql += " ON CONFLICT DO NOTHING;"
        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.execute(insert_sql, values)
        if conflict_columns:
            # On a conflict the stored row may differ in non-key columns, so
            # look it up by the conflict key only.
            data = {key: value for key, value in data.items() if key in conflict_columns}

        return self._get_primary_key(table, data)

    def insert_on_conflict_do_update(self, table, data, conflict_columns, set_columns):
        """Write an INSERT statement with an ON CONFLICT DO UPDATE clause.

        This is useful to update rows on insert. Intended for
        single row inserts.

        Args:
            table (DjangoModel): The model class whose table is inserted into
            data (dict): Column values to insert; passed through ``clean_data``,
                which modifies the dict in place
            conflict_columns (list): Columns to check conflict on
            set_columns (list): Columns overwritten with the incoming values
                on conflict

        Returns:
            The ``id`` of the row matching the ``conflict_columns`` values

        """
        # Read _meta from the class; instantiating the model is unnecessary
        # work and fires Django's pre_init/post_init signals.
        table_name = table._meta.db_table
        data = self.clean_data(data, table_name)

        set_clause = ",".join([f"{column} = excluded.{column}" for column in set_columns])

        columns_formatted = ", ".join(str(column) for column in data)
        values = list(data.values())
        val_str = ",".join(["%s"] * len(data))
        conflict_columns_formatted = ", ".join(conflict_columns)

        insert_sql = f"""
        INSERT INTO {self.schema}.{table_name}({columns_formatted}) VALUES ({val_str})
         ON CONFLICT ({conflict_columns_formatted}) DO UPDATE SET
         {set_clause}
        """
        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.execute(insert_sql, values)

        # Look the row up by its conflict key, which identifies it whether it
        # was inserted or updated.
        data = {key: value for key, value in data.items() if key in conflict_columns}

        return self._get_primary_key(table_name, data)

    def _get_primary_key(self, table_name, data):
        """Return the ``id`` of the first row in a table matching ``data``.

        Args:
            table_name (str or DjangoModel): The table to search, in any form
                accepted by ``_get_db_obj_query``
            data (dict): Field lookups passed to ``filter()``

        Returns:
            The ``id`` attribute of the first matching row

        Raises:
            AttributeError: if no row matches (``first()`` returned None); the
                table and lookup data are logged before re-raising
        """
        with schema_context(self.schema):
            matches = self._get_db_obj_query(table_name).filter(**data)
            try:
                row_id = matches.first().id
            except AttributeError:
                LOG.error("Row in %s does not exist in database.", table_name)
                LOG.error("Failed row data: %s", data)
                raise
            return row_id

    def clean_data(self, data, table_name):
        """Clean data for insertion into database.

        Empty strings become None. Values for integer, float, and decimal
        columns are converted with ``_convert_value``. ``data`` is modified in
        place and also returned.

        Args:
            data (dict): The data to be cleaned
            table_name (str): The table name the data is associated with

        Returns:
            (dict): The data with values converted to required types

        """
        column_types = self.report_schema.column_types[table_name]
        # ReportSchema fills column_types from Field.get_internal_type(), which
        # returns strings such as "IntegerField". Comparing those to Python
        # types (the old check) never matched, so match on the type names.
        # The Python types remain as keys for backward compatibility.
        converters = {
            int: int,
            "IntegerField": int,
            "BigIntegerField": int,
            "SmallIntegerField": int,
            "PositiveIntegerField": int,
            "PositiveSmallIntegerField": int,
            "PositiveBigIntegerField": int,
            float: float,
            "FloatField": float,
            Decimal: Decimal,
            "DecimalField": Decimal,
        }

        for key, value in data.items():
            if value is None or value == "":
                data[key] = None
                continue
            target_type = converters.get(column_types.get(key))
            if target_type is not None:
                data[key] = self._convert_value(value, target_type)

        return data

    def _convert_value(self, value, column_type):
        """Convert a single value to the specified column type.

        Decimal values are quantized to ``Decimal(self.decimal_precision)``.
        A value that cannot be converted is logged as a warning and replaced
        with None.

        Args:
            value (var): A value of any type
            column_type (type): A Python type, e.g. ``int``, ``float`` or ``Decimal``

        Returns:
            (var): The value converted to ``column_type``, or None if conversion fails.

        """
        if column_type == Decimal:
            try:
                return Decimal(value).quantize(Decimal(self.decimal_precision))
            except InvalidOperation as err:
                # Log like the generic branch below instead of silently
                # dropping the value.
                LOG.warning("Could not convert %r to Decimal: %s", value, err)
                return None
        try:
            return column_type(value)
        except ValueError as err:
            LOG.warning(err)
            return None

    def _execute_raw_sql_query(self, table, sql, start=None, end=None, bind_params=None):
        """Execute a raw SQL statement in this accessor's schema, logging progress.

        Args:
            table (str): Name shown in the log messages only
            sql (str): The SQL statement to execute
            start: Optional range start; logged only when both start and end are truthy
            end: Optional range end; logged only when both start and end are truthy
            bind_params: Parameters passed to ``cursor.execute`` as ``params``

        """
        if not (start and end):
            LOG.info("Updating %s", table)
        else:
            LOG.info("Updating %s from %s to %s.", table, start, end)

        with connection.cursor() as cursor:
            cursor.db.set_schema(self.schema)
            cursor.execute(sql, params=bind_params)
        LOG.info("Finished updating %s.", table)