from abc import ABC, abstractmethod
from unittest import TestCase, mock

from argparse import Namespace
import pymysql
import psycopg2
import logging
import pytest

from piicatcher.explorer.databases import MySQLExplorer, PostgreSQLExplorer, OracleExplorer, \
    RelDbExplorer
from piicatcher.explorer.sqlite import SqliteExplorer
from piicatcher.explorer.metadata import Schema, Table, Column
from piicatcher.piitypes import PiiTypes

logging.basicConfig(level=logging.DEBUG)

pii_data_script = """
create table no_pii(a text, b text);
insert into no_pii values ('abc', 'def');
insert into no_pii values ('xsfr', 'asawe');

create table partial_pii(a text, b text);
insert into partial_pii values ('917-908-2234', 'plkj');
insert into partial_pii values ('215-099-2234', 'sfrf');

create table full_pii(name text, location text);
insert into full_pii values ('Jonathan Smith', 'Virginia');
insert into full_pii values ('Chase Ryan', 'Chennai');

"""


class CommonExplorerTestCases:
    class CommonExplorerTests(TestCase, ABC):
        explorer = None

        @abstractmethod
        def get_test_schema(self):
            pass

        def test_columns(self):
            names = [col.get_name() for col in self.explorer.get_columns(self.get_test_schema(), "no_pii")]
            self.assertEqual(['a', 'b'], names)

        def test_tables(self):
            names = [tbl.get_name() for tbl in self.explorer.get_tables(self.get_test_schema())]
            self.assertEqual(sorted(['no_pii', 'partial_pii', 'full_pii']), sorted(names))

        def test_scan_dbexplorer(self):
            self.explorer.scan()
            for schema in self.explorer.get_schemas():
                if schema.get_name() == self.get_test_schema():
                    self.assertTrue(schema.has_pii())


char_data_types = """
create table char_columns(ch char, chn char(10), vchn varchar(10));
create table no_char_columns(i int, d float, t time, dt date, ts timestamp);
create table some_char_columns(vchn varchar(10), txt text, i int)
"""


class CommonDataTypeTestCases:
    class CommonDataTypeTests(TestCase, ABC):
        explorer = None

        @abstractmethod
        def get_test_schema(self):
            pass

        def test_char_columns(self):
            names = [col.get_name() for col in self.explorer.get_columns(self.get_test_schema(), "char_columns")]
            self.assertEqual(['ch', 'chn', 'vchn'], names)

        def test_some_char_columns(self):
            names = [col.get_name() for col in self.explorer.get_columns(self.get_test_schema(), "some_char_columns")]
            self.assertEqual(['txt', 'vchn'], names)

        def test_tables(self):
            names = [tbl.get_name() for tbl in self.explorer.get_tables(self.get_test_schema())]
            self.assertEqual(sorted(['char_columns', 'some_char_columns']), sorted(names))


@pytest.mark.usefixtures("create_tables")
class MySQLExplorerTest(CommonExplorerTestCases.CommonExplorerTests):
    pii_db_drop = """
        DROP TABLE full_pii;
        DROP TABLE partial_pii;
        DROP TABLE no_pii;
    """

    @staticmethod
    def execute_script(cursor, script):
        for query in script.split(';'):
            if len(query.strip()) > 0:
                cursor.execute(query)

    @pytest.fixture(scope="class")
    def create_tables(self, request):
        self.conn = pymysql.connect(host="127.0.0.1",
                                    user="piiuser",
                                    password="p11secret",
                                    database="piidb",
                                    )

        with self.conn.cursor() as cursor:
            self.execute_script(cursor, pii_data_script)
            cursor.execute("commit")
            cursor.close()

        def drop_tables():
            with self.conn.cursor() as cursor:
                self.execute_script(cursor, self.pii_db_drop)
                logging.info("Executed drop script")
                cursor.close()
            self.conn.close()

        request.addfinalizer(drop_tables)

    def setUp(self):
        self.explorer = MySQLExplorer(Namespace(
            host="127.0.0.1",
            user="piiuser",
            password="p11secret",
            database="piidb",
            include_schema=(),
            exclude_schema=(),
            include_table=(),
            exclude_table=(),
            catalog=None
        ))

    def tearDown(self):
        self.explorer.get_connection().close()

    def test_schema(self):
        names = [sch.get_name() for sch in self.explorer.get_schemas()]
        self.assertEqual(['piidb'], names)
        return "piidb"

    def get_test_schema(self):
        return "piidb"


@pytest.mark.usefixtures("create_tables")
class MySQLDataTypeTest(CommonDataTypeTestCases.CommonDataTypeTests):
    char_db_drop = """
        DROP TABLE char_columns;
        DROP TABLE no_char_columns;
        DROP TABLE some_char_columns;
    """

    @staticmethod
    def execute_script(cursor, script):
        for query in script.split(';'):
            if len(query.strip()) > 0:
                cursor.execute(query)

    @pytest.fixture(scope="class")
    def create_tables(self, request):
        self.conn = pymysql.connect(host="127.0.0.1",
                                    user="piiuser",
                                    password="p11secret",
                                    database="piidb"
                                    )

        with self.conn.cursor() as cursor:
            self.execute_script(cursor, char_data_types)
            cursor.execute("commit")
            cursor.close()

        def drop_tables():
            with self.conn.cursor() as drop_cursor:
                self.execute_script(drop_cursor, self.char_db_drop)
                logging.info("Executed drop script")
                drop_cursor.close()
            self.conn.close()

        request.addfinalizer(drop_tables)

    def setUp(self):
        self.explorer = MySQLExplorer(Namespace(
            host="127.0.0.1",
            user="piiuser",
            password="p11secret",
            database="piidb",
            include_schema=(),
            exclude_schema=(),
            include_table=(),
            exclude_table=(),
            catalog=None
        ))

    def tearDown(self):
        self.explorer.get_connection().close()

    def get_test_schema(self):
        return "piidb"


@pytest.mark.usefixtures("create_tables")
class PostgresDataTypeTest(CommonDataTypeTestCases.CommonDataTypeTests):
    char_db_drop = """
        DROP TABLE char_columns;
        DROP TABLE no_char_columns;
        DROP TABLE some_char_columns;
    """

    @staticmethod
    def execute_script(cursor, script):
        for query in script.split(';'):
            if len(query.strip()) > 0:
                cursor.execute(query)

    @pytest.fixture(scope="class")
    def create_tables(self, request):
        self.conn = psycopg2.connect(host="127.0.0.1",
                                     user="piiuser",
                                     password="p11secret",
                                     database="piidb")

        self.conn.autocommit = True

        with self.conn.cursor() as cursor:
            self.execute_script(cursor, char_data_types)
            cursor.close()

        def drop_tables():
            with self.conn.cursor() as d_cursor:
                d_cursor.execute(self.char_db_drop)
                logging.info("Executed drop script")
                d_cursor.close()
            self.conn.close()

        request.addfinalizer(drop_tables)

    def setUp(self):
        self.explorer = PostgreSQLExplorer(Namespace(
            host="127.0.0.1",
            user="piiuser",
            password="p11secret",
            database="piidb",
            include_schema=(),
            exclude_schema=(),
            include_table=(),
            exclude_table=(),
            catalog=None
        ))

    def tearDown(self):
        self.explorer.get_connection().close()

    def get_test_schema(self):
        return "public"


@pytest.mark.usefixtures("create_tables")
class PostgresExplorerTest(CommonExplorerTestCases.CommonExplorerTests):
    pii_db_drop = """
        DROP TABLE full_pii;
        DROP TABLE partial_pii;
        DROP TABLE no_pii;
        DROP SCHEMA company cascade;
    """

    second_schema = """
        CREATE SCHEMA company;
        CREATE TABLE company.employees(name varchar, designation varchar);
        CREATE TABLE company.departments(name varchar, manager varchar);
    """
    @staticmethod
    def execute_script(cursor, script):
        for query in script.split(';'):
            if len(query.strip()) > 0:
                cursor.execute(query)

    @pytest.fixture(scope="class")
    def create_tables(self, request):
        self.conn = psycopg2.connect(host="127.0.0.1",
                                     user="piiuser",
                                     password="p11secret",
                                     database="piidb")

        self.conn.autocommit = True

        with self.conn.cursor() as cursor:
            self.execute_script(cursor, pii_data_script)
            self.execute_script(cursor, self.second_schema)
            cursor.close()

        def drop_tables():
            with self.conn.cursor() as d_cursor:
                d_cursor.execute(self.pii_db_drop)
                logging.info("Executed drop script")
                d_cursor.close()
            self.conn.close()

        request.addfinalizer(drop_tables)

    def setUp(self):
        self.explorer = PostgreSQLExplorer(Namespace(
            host="127.0.0.1",
            user="piiuser",
            password="p11secret",
            database="piidb",
            include_schema=(),
            exclude_schema=(),
            include_table=(),
            exclude_table=(),
            catalog=None
        ))

    def tearDown(self):
        self.explorer.get_connection().close()

    def test_schema(self):
        names = [sch.get_name() for sch in self.explorer.get_schemas()]
        self.assertCountEqual(['public', 'company'], names)

    def get_test_schema(self):
        return "public"


class SelectQueryTest(TestCase):
    def setUp(self):
        col1 = Column('c1')
        col2 = Column('c2')
        col2._pii = [PiiTypes.LOCATION]

        self.schema = Schema('testSchema')

        table = Table(self.schema, 't1')
        table.add_child(col1)
        table.add_child(col2)

        self.schema.add_child(table)

    def test_oracle(self):
        self.assertEqual("select c1,c2 from t1",
                         OracleExplorer._get_select_query(self.schema,
                                                          self.schema.get_children()[0],
                                                          self.schema.get_children()[0].get_children()))

    def test_sqlite(self):
        self.assertEqual("select c1,c2 from t1",
                         SqliteExplorer._get_select_query(self.schema,
                                                          self.schema.get_children()[0],
                                                          self.schema.get_children()[0].get_children()))

    def test_postgres(self):
        self.assertEqual("select c1,c2 from testSchema.t1",
                         PostgreSQLExplorer._get_select_query(self.schema,
                                                              self.schema.get_children()[0],
                                                              self.schema.get_children()[0].get_children()))

    def test_mysql(self):
        self.assertEqual("select c1,c2 from testSchema.t1",
                         MySQLExplorer._get_select_query(self.schema,
                                                         self.schema.get_children()[0],
                                                         self.schema.get_children()[0].get_children()))


class TestDispatcher(TestCase):

    def test_mysql_dispatch(self):
        with mock.patch('piicatcher.explorer.databases.MySQLExplorer.scan', autospec=True) \
                as mock_scan_method:
            with mock.patch('piicatcher.explorer.databases.MySQLExplorer.get_tabular',
                            autospec=True) as mock_tabular_method:
                with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) \
                        as MockTablePrint:
                    RelDbExplorer.dispatch(Namespace(host='connection',
                                                     port=None,
                                                     list_all=None,
                                                     connection_type='mysql',
                                                     scan_type='deep',
                                                     catalog={
                                                         'format': 'ascii_table'
                                                     },
                                                     user='user',
                                                     include_schema=(),
                                                     exclude_schema=(),
                                                     include_table=(),
                                                     exclude_table=(),
                                                     password='pass'))
                    mock_scan_method.assert_called_once()
                    mock_tabular_method.assert_called_once()
                    MockTablePrint.table.assert_called_once()

    def test_postgres_dispatch(self):
        with mock.patch('piicatcher.explorer.databases.PostgreSQLExplorer.scan', autospec=True) \
                as mock_scan_method:
            with mock.patch('piicatcher.explorer.databases.PostgreSQLExplorer.get_tabular',
                            autospec=True) as mock_tabular_method:
                with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) \
                        as MockTablePrint:
                    RelDbExplorer.dispatch(Namespace(host='connection',
                                                     port=None,
                                                     list_all=None,
                                                     connection_type='postgres',
                                                     database='public',
                                                     scan_type=None,
                                                     catalog={
                                                         'format': 'ascii_table'
                                                     },
                                                     include_schema=(),
                                                     exclude_schema=(),
                                                     include_table=(),
                                                     exclude_table=(),
                                                     user='user',
                                                     password='pass'))
                    mock_scan_method.assert_called_once()
                    mock_tabular_method.assert_called_once()
                    MockTablePrint.table.assert_called_once()

    def test_mysql_shallow_scan(self):
        with mock.patch('piicatcher.explorer.databases.MySQLExplorer.shallow_scan',
                        autospec=True) as mock_shallow_scan_method:
            with mock.patch('piicatcher.explorer.databases.MySQLExplorer.get_tabular',
                            autospec=True) as mock_tabular_method:
                with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) \
                        as MockTablePrint:
                    RelDbExplorer.dispatch(Namespace(host='connection',
                                                     port=None,
                                                     list_all=None,
                                                     connection_type='mysql',
                                                     catalog={
                                                         'format': 'ascii_table'
                                                     },
                                                     include_schema=(),
                                                     exclude_schema=(),
                                                     include_table=(),
                                                     exclude_table=(),
                                                     user='user',
                                                     password='pass',
                                                     scan_type="shallow"))
                    mock_shallow_scan_method.assert_called_once()
                    mock_tabular_method.assert_called_once()
                    MockTablePrint.table.assert_called_once()