import os
import sqlite3

from django.db import connections
from django.test import TestCase
from django_dynamic_fixture import G

from dynamic_db_router import DynamicDbRouter, in_database
from .models import TestModel


class TestInDataBaseContextManager(TestCase):
    multi_db = True

    def test_string_identifier(self):
        G(TestModel, name='Arnold')
        with in_database('default'):
            count = TestModel.objects.count()
        expected = 1
        self.assertEqual(count, expected)

    def test_readonly_connection_writes_to_default(self):
        with in_database('test'):
            G(TestModel, name='Arnold')
            test_count = TestModel.objects.count()
        default_count = TestModel.objects.count()
        self.assertEqual(test_count, 0)
        self.assertEqual(default_count, 1)

    def test_write_connection_writes_to_test(self):
        with in_database('test', write=True):
            G(TestModel, name='Arnold')
            test_count = TestModel.objects.count()
        default_count = TestModel.objects.count()
        self.assertEqual(test_count, 1)
        self.assertEqual(default_count, 0)

    def test_write_only_connection_reads_from_default(self):
        with in_database('test', read=False, write=True):
            G(TestModel, name='Arnold')
            test_count = TestModel.objects.count()
        default_count = TestModel.objects.count()
        self.assertEqual(test_count, 0)
        self.assertEqual(default_count, 0)

    def test_recursive_context_manager(self):
        with in_database('test', write=True):
            G(TestModel, name='Arnold')
            with in_database('default', write=True):
                pass
            test_count = TestModel.objects.count()
        self.assertEqual(test_count, 1)

    def test_bad_input_value(self):
        with self.assertRaises(ValueError):
            with in_database(2):
                pass


class TestDynamicDatabaseConnection(TestCase):
    def setUp(self):
        # Create a sqlite database with the models that django will
        # expect.
        PROJECT_DIR = os.path.abspath(os.path.dirname(__file__))
        self.db_filename = os.path.join(PROJECT_DIR, 'dynamic_test_router.db')
        conn = sqlite3.connect(self.db_filename)
        cur = conn.cursor()
        create_table_query = (
            'CREATE TABLE tests_testmodel('
            '    id PRIMARY KEY, name varchar(32));'
        )
        cur.execute(create_table_query)
        conn.commit()
        conn.close()

        # The database configuration to use with in_database
        self.test_db_config = {
            'ENGINE': 'django.db.backends.sqlite3',
            'NAME': self.db_filename,
        }

    def tearDown(self):
        os.remove(self.db_filename)

    def test_create_db_object(self):
        with in_database(self.test_db_config, write=True):
            G(TestModel, name='Arnold')
            G(TestModel, name='Sue')
            count = TestModel.objects.count()
        expected = 2
        self.assertEqual(count, expected)

    def test_where_db_objects_come_from(self):
        with in_database(self.test_db_config, write=True) as x:
            G(TestModel, name='Sue')
            database_name = TestModel.objects.get(name='Sue')._state.db
        expected_database_name = x.unique_db_id
        self.assertEqual(database_name, expected_database_name)

    def test_cleans_up(self):
        starting_connections = len(connections.databases)
        with in_database(self.test_db_config, write=True):
            G(TestModel, name='Sue')
        ending_connections = len(connections.databases)
        self.assertEqual(starting_connections, ending_connections)


class TestInDatabaseDecorator(TestCase):
    def test_decorator_matches_context_manager(self):
        @in_database('test')
        def test_db_record_count():
            return TestModel.objects.count()

        with in_database('test'):
            G(TestModel, name='Michael Bluth')
            context_count = TestModel.objects.count()
        decorator_count = test_db_record_count()
        self.assertEqual(context_count, decorator_count)


class TestDynamicDbRouterDefaults(TestCase):
    def test_db_for_read(self):
        router = DynamicDbRouter()
        db_for_read = router.db_for_read(None)
        self.assertIn(db_for_read, ['default', None])

    def test_db_for_wrte(self):
        router = DynamicDbRouter()
        db_for_write = router.db_for_write(None)
        self.assertIn(db_for_write, ['default', None])

    def test_allow_relation(self):
        router = DynamicDbRouter()
        allow_relation = router.allow_relation(None, None)
        self.assertEqual(allow_relation, True)

    def test_allow_syncdb(self):
        router = DynamicDbRouter()
        allow_syncdb = router.allow_syncdb(None, None)
        self.assertEqual(allow_syncdb, None)

    def test_allow_migrate(self):
        router = DynamicDbRouter()
        allow_migrate = router.allow_migrate(None, None)
        self.assertEqual(allow_migrate, None)