import itertools
import os.path
import pprint
import sqlite3
import time
from collections import defaultdict
from multiprocessing.dummy import Pool

import boto3
from boto3.resources.collection import CollectionManager
from botocore.exceptions import NoCredentialsError

from aq import logger, util, sqlite_util
from aq.errors import QueryError

DEFAULT_REGION = 'us_east_1'

LOGGER = logger.get_logger()


class BotoSqliteEngine(object):
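    """
    Query engine that loads AWS resources via boto3 into per-region SQLite databases
    (one attached database per region) and runs SQL queries against them.

    Rough usage sketch; the query string and its parsed metadata (an object with a
    ``tables`` attribute, see load_tables) are normally produced by the rest of aq,
    so the names below are illustrative only:

        engine = BotoSqliteEngine({'--region': 'us-west-2'})
        columns, rows = engine.execute(query, query_metadata)
    """
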
    def __init__(self, options=None):
        self.options = options if options else {}
        self.debug = self.options.get('--debug', False)

        self.profile = self.options.get('--profile', None)
        self.region = self.options.get('--region', None)
        self.table_cache_ttl = int(self.options.get('--table-cache-ttl', 300))
        self.last_refresh_time = defaultdict(int)

        self.boto3_session = boto3.Session(profile_name=self.profile)
        # a dash (-) is not allowed in a database name, so we use an underscore (_) instead in
        # region names; throughout this module region names will *always* use underscores
        if self.region:
            self.default_region = self.region.replace('-', '_')
        elif self.boto3_session.region_name:
            self.default_region = self.boto3_session.region_name.replace('-', '_')
        else:
            self.default_region = DEFAULT_REGION

        self.boto3_session = boto3.Session(profile_name=self.profile,
                                           region_name=self.default_region.replace('_', '-'))
        self.db = self.init_db()
        # attach the default region too
        self.attach_region(self.default_region)

    def init_db(self):
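        """
        Ensure the data directory exists and connect to the SQLite database file
        for the default region under ~/.aq/.
        """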
        util.ensure_data_dir_exists()
        db_path = '~/.aq/{0}.db'.format(self.default_region)
        absolute_path = os.path.expanduser(db_path)
        return sqlite_util.connect(absolute_path)

    def execute(self, query, metadata):
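        """
        Execute the given SQL query, loading the required resource tables first.

        Returns a tuple of (column names, rows). Raises QueryError if SQLite rejects the query.
        """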
        LOGGER.info('Executing query: %s', query)
        self.load_tables(query, metadata)
        try:
            cursor = self.db.execute(query)
        except sqlite3.OperationalError as e:
            raise QueryError(str(e))
        columns = [d[0] for d in cursor.description]
        rows = cursor.fetchall()
        return columns, rows

    def load_tables(self, query, meta):
        """
        Load necessary resources tables into db to execute given query.
        """
        try:
            for table in meta.tables:
                self.load_table(table)
        except NoCredentialsError:
            help_link = 'http://boto3.readthedocs.io/en/latest/guide/configuration.html'
            raise QueryError('Unable to locate AWS credentials. '
                             'Please see {0} for how to configure AWS credentials.'.format(help_link))

    def load_table(self, table):
        """
        Load resources as specified by given table into our db.
        """
        region = table.database if table.database else self.default_region
        resource_name, collection_name = table.table.split('_', 1)
        # we use an underscore "_" instead of a dash "-" in region names, but boto3 needs the dash
        boto_region_name = region.replace('_', '-')
        resource = self.boto3_session.resource(resource_name, region_name=boto_region_name)
        if not hasattr(resource, collection_name):
            raise QueryError(
                'Unknown collection <{0}> of resource <{1}>'.format(collection_name, resource_name))

        self.attach_region(region)
        self.refresh_table(region, table.table, resource, getattr(resource, collection_name))

    def attach_region(self, region):
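        """
        Attach the region's database file under ~/.aq/ to this connection if it is not
        attached already.
        """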
        if not self.is_attached_region(region):
            LOGGER.info('Attaching new database for region: %s', region)
            region_db_file_path = '~/.aq/{0}.db'.format(region)
            absolute_path = os.path.expanduser(region_db_file_path)
            self.db.execute('ATTACH DATABASE ? AS ?', (absolute_path, region))

    def is_attached_region(self, region):
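        """
        Check whether a database with the given region name is already attached.
        """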
        databases = self.db.execute('PRAGMA database_list')
        db_names = (db[1] for db in databases)
        return region in db_names

    def refresh_table(self, schema_name, table_name, resource, collection):
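        """
        Recreate and repopulate the given table from its boto3 collection, unless the
        cached copy is still within the table cache TTL.
        """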
        if not self.is_fresh_enough(schema_name, table_name):
            LOGGER.info('Refreshing table: %s.%s', schema_name, table_name)
            columns = get_columns_list(resource, collection)
            LOGGER.info('Columns list: %s', columns)
            with self.db:
                sqlite_util.create_table(self.db, schema_name, table_name, columns)
                items = collection.all()
                # special treatment for the tags field (see convert_tags_to_dict)
                items = [convert_tags_to_dict(item) for item in items]
                sqlite_util.insert_all(self.db, schema_name, table_name, columns, items)
                self.last_refresh_time[(schema_name, table_name)] = time.time()

    def is_fresh_enough(self, schema_name, table_name):
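        """
        Return True if the table was refreshed less than table_cache_ttl seconds ago.
        """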
        last_refresh = self.last_refresh_time[(schema_name, table_name)]
        age = time.time() - last_refresh
        return age < self.table_cache_ttl

    @property
    def available_schemas(self):
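        """
        All available schema names, i.e. region names with underscores.
        """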
        # we want to return all regions if possible; ec2 is available in every region,
        # so it is a good enough guess
        regions = self.boto3_session.get_available_regions(service_name='ec2')
        return [r.replace('-', '_') for r in regions]

    @property
    def available_tables(self):
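        """
        All available table names in <resource>_<collection> form, across every boto3 resource.
        """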
        resources = self.boto3_session.get_available_resources()
        tables = Pool(processes=len(resources)).map(self._get_table_names_for_resource, resources)
        return itertools.chain.from_iterable(tables)

    def _get_table_names_for_resource(self, resource_name):
        # return a list (not a generator) so the work is actually performed inside the thread pool
        resource = self.boto3_session.resource(resource_name)
        return ['{0}_{1}'.format(resource_name, attr)
                for attr in dir(resource)
                if isinstance(getattr(resource, attr), CollectionManager)]


class ObjectProxy(object):
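    """
    Wrap an object and override selected attributes with replacement values, delegating
    all other attribute access to the wrapped object.
    """
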
    def __init__(self, source, **replaced_fields):
        self.source = source
        self.replaced_fields = replaced_fields

    def __getattr__(self, item):
        if item in self.replaced_fields:
            return self.replaced_fields[item]
        return getattr(self.source, item)


def convert_tags_to_dict(item):
    """
    Convert AWS inconvenient tags model of a list of {"Key": <key>, "Value": <value>} pairs
    to a dict of {<key>: <value>} for easier querying.

    This returns a proxied object over given item to return a different tags format as the tags
    attribute is read-only and we cannot modify it directly.
    """
    if hasattr(item, 'tags'):
        tags = item.tags
        if isinstance(tags, list):
            tags_dict = {}
            for kv_dict in tags:
                if isinstance(kv_dict, dict) and 'Key' in kv_dict and 'Value' in kv_dict:
                    tags_dict[kv_dict['Key']] = kv_dict['Value']
            return ObjectProxy(item, tags=tags_dict)
    return item


def get_resource_model_attributes(resource, collection):
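    """
    Return the attributes of the resource model backing the given collection, resolved
    against the corresponding shape in the service model.
    """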
    service_model = resource.meta.client.meta.service_model
    resource_model = get_resource_model(collection)
    shape_name = resource_model.shape
    shape = service_model.shape_for(shape_name)
    return resource_model.get_attributes(shape)


def get_columns_list(resource, collection):
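    """
    Build the column list for a resource collection: the sorted model identifiers followed
    by the model attributes.
    """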
    resource_model = get_resource_model(collection)
    LOGGER.debug('Resource model: %s', resource_model)

    identifiers = sorted(i.name for i in resource_model.identifiers)
    LOGGER.debug('Model identifiers: %s', identifiers)

    attributes = get_resource_model_attributes(resource, collection)
    LOGGER.debug('Model attributes: %s', pprint.pformat(attributes))

    return list(itertools.chain(identifiers, attributes))


def get_resource_model(collection):
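    # note: this reaches into boto3 internals (the private _model attribute of the collection)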
    return collection._model.resource.model