"""Use to transfer a MySQL database to SQLite."""

from __future__ import division

import logging
import re
import sqlite3
import sys
from datetime import timedelta
from decimal import Decimal
from math import ceil
from os.path import realpath

import mysql.connector
import six
from mysql.connector import errorcode  # pylint: disable=C0412
from tqdm import trange

from mysql_to_sqlite3.sqlite_utils import (  # noqa: ignore=I100
    adapt_decimal,
    adapt_timedelta,
    convert_decimal,
    convert_timedelta,
    encode_data_for_sqlite,
)

if six.PY2:
    from .sixeptions import *  # pylint: disable=W0622,W0401,W0614


class MySQLtoSQLite:  # pylint: disable=R0902,R0903
    """Use this class to transfer a MySQL database to SQLite."""

    COLUMN_PATTERN = re.compile(r"^[^(]+")
    COLUMN_LENGTH_PATTERN = re.compile(r"\(\d+\)$")

    def __init__(self, **kwargs):
        """Validate the options, open the SQLite file and connect to MySQL.

        Recognised keyword arguments:

        * ``mysql_database`` (required): name of the MySQL database to read.
        * ``mysql_user`` (required): MySQL user name.
        * ``sqlite_file`` (required): path of the SQLite file to write to.
        * ``mysql_password``: MySQL password; ``None`` when not given.
        * ``mysql_host``: MySQL host, ``"localhost"`` by default.
        * ``mysql_port``: MySQL port, ``3306`` by default.
        * ``mysql_tables``: iterable of table names to restrict the transfer
          to. When it is non-empty, foreign keys are never transferred.
        * ``without_foreign_keys``: skip foreign key constraints.
        * ``chunk``: number of rows to read per batch; when unset the rows
          of a table are fetched in one go.
        * ``buffered``: passed as ``buffered`` to the raw and the dictionary
          MySQL cursors.
        * ``vacuum``: run ``VACUUM`` on the SQLite database after transfer.
        * ``log_file``: path of an additional log file.

        Raises:
            ValueError: if a required option is missing.
            ConnectionError: if the MySQL connection reports that it is not
                connected.
            mysql.connector.Error: if connecting to MySQL fails. Any
                exception raised while selecting the database is logged
                and re-raised as well.
        """
        # Check every required option before touching the logger, the
        # sqlite3 module or the network, so a bad call has no side effects.
        if not kwargs.get("mysql_database"):
            raise ValueError("Please provide a MySQL database")

        if not kwargs.get("mysql_user"):
            raise ValueError("Please provide a MySQL user")

        if not kwargs.get("sqlite_file"):
            raise ValueError("Please provide an SQLite file")

        self._mysql_database = str(kwargs.get("mysql_database"))

        self._mysql_tables = (
            tuple(kwargs.get("mysql_tables"))
            if kwargs.get("mysql_tables") is not None
            else tuple()
        )

        # A partial transfer cannot guarantee that referenced tables exist,
        # so foreign keys are always left out when tables are hand-picked.
        self._without_foreign_keys = (
            True
            if self._mysql_tables
            else (kwargs.get("without_foreign_keys") or False)
        )

        self._mysql_user = str(kwargs.get("mysql_user"))

        self._mysql_password = (
            str(kwargs.get("mysql_password")) if kwargs.get("mysql_password") else None
        )

        self._mysql_host = str(kwargs.get("mysql_host") or "localhost")

        self._mysql_port = int(kwargs.get("mysql_port") or 3306)

        self._current_chunk_number = 0
        self._chunk_size = int(kwargs.get("chunk")) if kwargs.get("chunk") else None

        self._sqlite_file = kwargs.get("sqlite_file")

        self._buffered = kwargs.get("buffered") or False

        self._vacuum = kwargs.get("vacuum") or False

        self._logger = self._setup_logger(log_file=kwargs.get("log_file") or None)

        # Teach sqlite3 how to store and load the Python types that have no
        # native SQLite representation.
        sqlite3.register_adapter(Decimal, adapt_decimal)
        sqlite3.register_converter("DECIMAL", convert_decimal)
        sqlite3.register_adapter(timedelta, adapt_timedelta)
        sqlite3.register_converter("TIME", convert_timedelta)

        self._sqlite = sqlite3.connect(
            realpath(self._sqlite_file), detect_types=sqlite3.PARSE_DECLTYPES
        )
        self._sqlite.row_factory = sqlite3.Row

        self._sqlite_cur = self._sqlite.cursor()

        try:
            self._mysql = mysql.connector.connect(
                user=self._mysql_user,
                password=self._mysql_password,
                host=self._mysql_host,
                port=self._mysql_port,
            )
            if not self._mysql.is_connected():
                raise ConnectionError("Unable to connect to MySQL")

            self._mysql_cur = self._mysql.cursor(buffered=self._buffered, raw=True)
            self._mysql_cur_prepared = self._mysql.cursor(prepared=True)
            self._mysql_cur_dict = self._mysql.cursor(
                buffered=self._buffered, dictionary=True,
            )
            try:
                self._mysql.database = self._mysql_database
            except Exception as err:  # pylint: disable=W0703
                # Give the most common failure a friendlier message; every
                # other error is logged verbatim. Both are re-raised.
                if getattr(err, "errno", None) == errorcode.ER_BAD_DB_ERROR:
                    self._logger.error("MySQL Database does not exist!")
                else:
                    self._logger.error(err)
                raise
        except mysql.connector.Error as err:
            self._logger.error(err)
            raise

    @classmethod
    def _setup_logger(cls, log_file=None):
        formatter = logging.Formatter(
            fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        screen_handler = logging.StreamHandler(stream=sys.stdout)
        screen_handler.setFormatter(formatter)
        logger = logging.getLogger(cls.__name__)
        logger.setLevel(logging.DEBUG)
        logger.addHandler(screen_handler)

        if log_file:
            file_handler = logging.FileHandler(realpath(log_file), mode="w")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        return logger

    @classmethod
    def _valid_column_type(cls, column_type):
        """Match the type name at the start of a MySQL column type.

        Surrounding whitespace is ignored. The returned match covers
        everything before the first ``(``; ``None`` is returned when the
        trimmed text is empty or starts with ``(``.
        """
        trimmed = column_type.strip()
        return cls.COLUMN_PATTERN.match(trimmed)

    @classmethod
    def _column_type_length(cls, column_type):
        """Return the ``(N)`` suffix that ends a column type, or ``""``.

        Only a purely numeric, parenthesised suffix at the very end of the
        text counts, e.g. ``"(255)"`` for ``"varchar(255)"``.
        """
        found = cls.COLUMN_LENGTH_PATTERN.search(column_type)
        return found.group(0) if found else ""

    @classmethod
    def _translate_type_from_mysql_to_sqlite(cls, column_type):  # pylint: disable=R0911
        """Translate a MySQL column type into the type used in the SQLite DDL.

        ``column_type`` is the type text as reported by MySQL, for example
        ``"int(11) unsigned"``, ``"varchar(255)"`` or ``"bigint unsigned"``.
        Only the type name in front of the first ``(`` is classified; the
        ``UNSIGNED`` / ``ZEROFILL`` attributes that follow the name directly
        when no display width is present are ignored.

        Returns:
            The SQLite type name. Character types keep their ``(N)`` length;
            anything that is not recognised becomes ``"TEXT"``.

        Raises:
            ValueError: if no type name can be extracted from ``column_type``.
        """
        match = cls._valid_column_type(column_type)
        if not match:
            raise ValueError("Invalid column_type!")

        # "int unsigned" must be classified as INT, not as the unknown type
        # "INT UNSIGNED" (which would silently degrade the column to TEXT).
        words = match.group(0).upper().split()
        while len(words) > 1 and words[-1] in {"UNSIGNED", "ZEROFILL"}:
            words.pop()
        data_type = " ".join(words)

        # Types that SQLite can take over under the very same name.
        if data_type in {
            "BIGINT",
            "BLOB",
            "BOOLEAN",
            "DATE",
            "DATETIME",
            "DECIMAL",
            "DOUBLE",
            "FLOAT",
            "INTEGER",
            "MEDIUMINT",
            "NUMERIC",
            "REAL",
            "SMALLINT",
            "TIME",
            "TINYINT",
            "YEAR",
        }:
            return data_type
        # Every binary flavour collapses into BLOB.
        if data_type in {
            "BIT",
            "BINARY",
            "LONGBLOB",
            "MEDIUMBLOB",
            "TINYBLOB",
            "VARBINARY",
        }:
            return "BLOB"
        if data_type in {"NCHAR", "NVARCHAR", "VARCHAR"}:
            return data_type + cls._column_type_length(column_type)
        if data_type == "CHAR":
            return "CHARACTER" + cls._column_type_length(column_type)
        if data_type == "INT":
            return "INTEGER"
        if data_type == "TIMESTAMP":
            return "DATETIME"
        return "TEXT"

    def _build_create_table_sql(self, table_name):
        """Build the SQLite script that recreates one MySQL table.

        The MySQL dictionary cursor is queried three times: for the columns
        (``SHOW COLUMNS``), for the indices (``information_schema.STATISTICS``)
        and, unless foreign keys are disabled, for the foreign key
        constraints of the table.

        Returns:
            A script consisting of one ``CREATE TABLE IF NOT EXISTS``
            statement (columns, primary key, foreign keys) followed by one
            ``CREATE [UNIQUE] INDEX IF NOT EXISTS`` statement per
            non-primary index.
        """

        def quote_columns(column_list):
            # "a,b" -> '"a", "b"'
            return ", ".join(
                '"{}"'.format(column) for column in column_list.split(",")
            )

        self._mysql_cur_dict.execute("SHOW COLUMNS FROM `{}`".format(table_name))

        # Everything that goes between the parentheses of CREATE TABLE, in
        # order: columns first, then the primary key, then foreign keys.
        definitions = [
            '"{name}" {type} {notnull}'.format(
                name=row["Field"],
                type=self._translate_type_from_mysql_to_sqlite(row["Type"]),
                notnull="NULL" if row["Null"] == "YES" else "NOT NULL",
            )
            for row in self._mysql_cur_dict.fetchall()
        ]
        index_statements = []

        self._mysql_cur_dict.execute(
            """
            SELECT INDEX_NAME AS `name`,
                  IF (NON_UNIQUE = 0 AND INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`,
                  IF (NON_UNIQUE = 0 AND INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`,
                  GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns`
            FROM information_schema.STATISTICS
            WHERE TABLE_SCHEMA = %s
            AND TABLE_NAME = %s
            GROUP BY INDEX_NAME, NON_UNIQUE
            """,
            (self._mysql_database, table_name),
        )
        for index in self._mysql_cur_dict.fetchall():
            if int(index["primary"]) == 1:
                definitions.append(
                    "PRIMARY KEY ({})".format(quote_columns(index["columns"]))
                )
                continue
            # IF NOT EXISTS keeps the script re-runnable against an existing
            # SQLite file, just like the CREATE TABLE statement itself.
            index_statements.append(
                'CREATE {unique}INDEX IF NOT EXISTS "{name}" ON "{table}" ({columns});'.format(  # noqa: ignore=E501 pylint: disable=C0301
                    unique="UNIQUE " if int(index["unique"]) == 1 else "",
                    # combine the index name with the table name in order to
                    # make the index names unique across the database
                    name="{table}_{name}".format(table=table_name, name=index["name"]),
                    table=table_name,
                    columns=quote_columns(index["columns"]),
                )
            )

        if not self._without_foreign_keys:
            # Constraint names are only unique within a schema, so the joins
            # are tied to the schema of the constraint; otherwise a constraint
            # with the same name in another database would leak in.
            self._mysql_cur_dict.execute(
                """
                SELECT k.COLUMN_NAME AS `column`,
                       k.REFERENCED_TABLE_NAME AS `ref_table`,
                       k.REFERENCED_COLUMN_NAME AS `ref_column`,
                       c.UPDATE_RULE AS `on_update`,
                       c.DELETE_RULE AS `on_delete`
                FROM information_schema.TABLE_CONSTRAINTS AS i
                LEFT JOIN information_schema.KEY_COLUMN_USAGE AS k
                    ON i.CONSTRAINT_NAME = k.CONSTRAINT_NAME
                    AND i.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA
                LEFT JOIN information_schema.REFERENTIAL_CONSTRAINTS AS c
                    ON c.CONSTRAINT_NAME = i.CONSTRAINT_NAME
                    AND c.CONSTRAINT_SCHEMA = i.CONSTRAINT_SCHEMA
                WHERE i.TABLE_SCHEMA = %s
                AND i.TABLE_NAME = %s
                AND i.CONSTRAINT_TYPE = %s
                GROUP BY i.CONSTRAINT_NAME,
                         k.COLUMN_NAME,
                         k.REFERENCED_TABLE_NAME,
                         k.REFERENCED_COLUMN_NAME,
                         c.UPDATE_RULE,
                         c.DELETE_RULE
                """,
                (self._mysql_database, table_name, "FOREIGN KEY"),
            )
            for foreign_key in self._mysql_cur_dict.fetchall():
                definitions.append(
                    'FOREIGN KEY("{column}") REFERENCES "{ref_table}" ("{ref_column}") ON UPDATE {on_update} ON DELETE {on_delete}'.format(  # noqa: ignore=E501 pylint: disable=C0301
                        **foreign_key
                    )
                )

        sql = 'CREATE TABLE IF NOT EXISTS "{}" ('.format(table_name)
        if definitions:
            sql += "\n\t" + ",\n\t".join(definitions)
        sql += "\n);"
        return sql + "".join(index_statements)

    def _create_table(self, table_name, attempting_reconnect=False):
        """Create ``table_name`` and its indices in the SQLite database.

        The DDL comes from ``_build_create_table_sql`` and is committed
        right away. When the MySQL server connection is lost while reading
        the table definition, one reconnect-and-retry is attempted.

        Args:
            table_name: name of the MySQL table to recreate.
            attempting_reconnect: ``True`` for the retry call; the MySQL
                connection is re-established first and a further connection
                loss is not retried again.

        Raises:
            mysql.connector.Error: if reading the definition fails.
            sqlite3.Error: if SQLite rejects the generated script.
        """
        try:
            if attempting_reconnect:
                self._mysql.reconnect()
            self._sqlite_cur.executescript(self._build_create_table_sql(table_name))
            self._sqlite.commit()
        except mysql.connector.Error as err:
            if err.errno == errorcode.CR_SERVER_LOST:
                if attempting_reconnect:
                    self._logger.warning(
                        "Connection to MySQL server lost."
                        "\nReconnection attempt aborted."
                    )
                    raise
                self._logger.warning(
                    "Connection to MySQL server lost." "\nAttempting to reconnect."
                )
                # The retry either succeeds or raises on its own. Returning
                # here keeps a successful retry from falling through to the
                # error report below and re-raising the original error.
                self._create_table(table_name, True)
                return
            self._logger.error(
                "MySQL failed reading table definition from table %s: %s",
                table_name,
                err,
            )
            raise
        except sqlite3.Error as err:
            self._logger.error("SQLite failed creating table %s: %s", table_name, err)
            raise

    def _transfer_table_data(  # pylint: disable=C0330
        self, table_name, sql, total_records=0, attempting_reconnect=False
    ):
        if attempting_reconnect:
            self._mysql.reconnect()
        try:
            if self._chunk_size is not None and self._chunk_size > 0:
                for chunk in trange(
                    self._current_chunk_number,  # pylint: disable=C0330
                    int(
                        ceil(total_records / self._chunk_size)
                    ),  # pylint: disable=C0330
                ):
                    self._current_chunk_number = chunk
                    self._sqlite_cur.executemany(
                        sql,
                        (
                            tuple(
                                encode_data_for_sqlite(col) if col is not None else None
                                for col in row
                            )
                            for row in self._mysql_cur.fetchmany(self._chunk_size)
                        ),
                    )
            else:
                self._sqlite_cur.executemany(
                    sql,
                    (
                        tuple(
                            encode_data_for_sqlite(col) if col is not None else None
                            for col in row
                        )
                        for row in self._mysql_cur.fetchall()
                    ),
                )
            self._sqlite.commit()
        except mysql.connector.Error as err:
            if err.errno == errorcode.CR_SERVER_LOST:
                if not attempting_reconnect:
                    self._logger.warning(
                        "Connection to MySQL server lost." "\nAttempting to reconnect."
                    )
                    self._transfer_table_data(
                        table_name=table_name,
                        sql=sql,
                        total_records=total_records,
                        attempting_reconnect=True,
                    )
                else:
                    self._logger.warning(
                        "Connection to MySQL server lost."
                        "\nReconnection attempt aborted."
                    )
                    raise
            self._logger.error(
                "MySQL transfer failed reading table data from table %s: %s",
                table_name,
                err,
            )
            raise
        except sqlite3.Error as err:
            self._logger.error(
                "SQLite transfer failed inserting data into table %s: %s",
                table_name,
                err,
            )
            raise

    def transfer(self):
        """Transfer the schema and the data of the MySQL database to SQLite.

        Either the tables named in ``mysql_tables`` (those that actually
        exist in the database) or all tables of the database are processed.
        For each table the SQLite table is created and, when the MySQL
        table has rows, its data is copied. SQLite foreign key checks are
        switched off for the duration of the transfer and switched back on
        afterwards, even when the transfer fails. Finally the SQLite file
        is vacuumed if that was requested.

        Raises:
            mysql.connector.Error: if reading from MySQL fails.
            sqlite3.Error: if writing to SQLite fails.
        """
        if self._mysql_tables:
            # transfer only specific tables
            # pylint: disable=C0330
            self._mysql_cur_prepared.execute(
                """
                SELECT TABLE_NAME
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = SCHEMA()
                AND TABLE_NAME IN ({placeholders})
            """.format(
                    placeholders=", ".join(["%s"] * len(self._mysql_tables))
                ),
                self._mysql_tables,
            )
            rows = self._mysql_cur_prepared.fetchall()
        else:
            # transfer all tables
            self._mysql_cur.execute(
                """
                SELECT TABLE_NAME
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = SCHEMA()
            """
            )
            rows = self._mysql_cur.fetchall()

        # The raw cursor always yields binary names and the prepared cursor
        # may do so as well; the names are formatted into SQL text below, so
        # normalise both cases to text here.
        tables = [
            row[0].decode() if isinstance(row[0], (bytes, bytearray)) else row[0]
            for row in rows
        ]

        try:
            # turn off foreign key checking in SQLite while transferring data
            self._sqlite_cur.execute("PRAGMA foreign_keys=OFF")

            for table_name in tables:
                # reset the chunk
                self._current_chunk_number = 0

                # create the table
                self._create_table(table_name)

                # get the size of the data
                self._mysql_cur_dict.execute(
                    "SELECT COUNT(*) AS `total_records` FROM `{}`".format(table_name)
                )
                total_records = int(self._mysql_cur_dict.fetchone()["total_records"])

                # only continue if there is anything to transfer
                if total_records <= 0:
                    continue

                # populate it
                self._logger.info("Transferring table %s", table_name)
                self._mysql_cur.execute("SELECT * FROM `{}`".format(table_name))
                columns = [column[0] for column in self._mysql_cur.description]
                # build the SQL string; INSERT OR IGNORE skips rows that
                # collide with rows already present in the SQLite table
                sql = 'INSERT OR IGNORE INTO "{table}" ({fields}) VALUES ({placeholders})'.format(  # noqa: ignore=E501 pylint: disable=C0301
                    table=table_name,
                    fields=", ".join('"{}"'.format(column) for column in columns),
                    placeholders=", ".join(["?"] * len(columns)),
                )
                self._transfer_table_data(
                    table_name=table_name, sql=sql, total_records=total_records
                )
        finally:
            # re-enable foreign key checking once done transferring
            self._sqlite_cur.execute("PRAGMA foreign_keys=ON")

        if self._vacuum:
            self._logger.info(
                "Vacuuming created SQLite database file.\nThis might take a while."
            )
            self._sqlite_cur.execute("VACUUM")

        self._logger.info("Done!")