""" Cache module that implements the SSM caching wrapper """
from __future__ import absolute_import, print_function

from datetime import datetime, timedelta
from functools import wraps
import six

from ssm_cache.filters import SSMFilter

class InvalidParameterError(Exception):
    """ Raised when something's wrong with the provided param name """

class InvalidPathError(Exception):
    """ Raised when a given path is not properly structured """

class InvalidVersionError(Exception):
    """ Raised when something's wrong with the provided param version """


class Refreshable(object):
    """ Abstract class for refreshable objects (with max-age) """

    _ssm_client = None

    @classmethod
    def set_ssm_client(cls, client):
        """Override the default boto3 SSM client with your own."""
        required_methods = ('get_parameters', 'get_parameters_by_path')
        for method in required_methods:
            if not hasattr(client, method):
                raise TypeError('client must have a %s method' % method)
        cls._ssm_client = client

    @classmethod
    def _get_ssm_client(cls):
        if cls._ssm_client is None:
            import boto3
            cls._ssm_client = boto3.client('ssm')
        return cls._ssm_client

    def __init__(self, max_age):
        self._last_refresh_time = None
        self._max_age = max_age
        self._max_age_delta = timedelta(seconds=max_age or 0)

    def _refresh(self):
        raise NotImplementedError

    def _should_refresh(self):
        # never force refresh if no max_age is configured
        if not self._max_age:
            return False
        # always force refresh if values were never fetched
        if not self._last_refresh_time:
            return True
        # force refresh only if max_age seconds have expired
        return datetime.utcnow() > self._last_refresh_time + self._max_age_delta

    def _update_refresh_time(self, keep_oldest_value=False):
        """
            Update internal reference with current time.
            Optionally, keep the oldest available reference
            (used by groups with multiple fetch operations at potentially different times)
        """
        now = datetime.utcnow()
        if keep_oldest_value and self._last_refresh_time:
            self._last_refresh_time = min(now, self._last_refresh_time)
        else:
            self._last_refresh_time = now

    def refresh(self):
        """ Updates the value(s) of this refreshable """
        self._refresh()
        # keep track of update date for max_age checks
        self._update_refresh_time()

    @staticmethod
    def _parse_value(param_value, param_type):
        if param_type == 'StringList':
            return param_value.split(',')
        return param_value

    @classmethod
    def _get_parameters(cls, names, with_decryption):
        items = {}
        invalid_names = []
        for name_batch in _batch(names, 10): # can only get 10 parameters at a time
            response = SSMParameter._get_ssm_client().get_parameters(
                Names=list(name_batch),
                WithDecryption=with_decryption,
            )
            invalid_names.extend(response['InvalidParameters'])
            for item in response['Parameters']:
                item['Value'] = cls._parse_value(item['Value'], item['Type'])
                items[item['Name']] = item

        return items, invalid_names

    @classmethod
    def _get_parameters_by_path(cls, with_decryption, path, recursive=True, filters=None):
        """ Return all the parameters under the given path """
        items = {}

        # boto3 paginators doc: http://boto3.readthedocs.io/en/latest/guide/paginators.html
        client = SSMParameter._get_ssm_client()
        has_builtin_paginator = hasattr(client, 'get_paginator')

        def get_pages():
            """ Small utility to implement optional pagination (if native boto3 client) """
            if has_builtin_paginator:
                method = client.get_paginator('get_parameters_by_path').paginate
            else:
                method = client.get_parameters_by_path

            def serialize_filter(filter_obj):
                """ Utility function for serialization """
                if isinstance(filter_obj, SSMFilter):
                    return filter_obj.to_dict()
                return filter_obj

            # result will be a list of pages if built-in pagination
            # otherwise a single "page" is expected
            result = method(
                Path=path,
                Recursive=recursive,
                WithDecryption=with_decryption,
                ParameterFilters=[
                    serialize_filter(filter_obj)
                    for filter_obj in (filters or [])
                ],
            )

            return result if has_builtin_paginator else [result]

        for page in get_pages():
            for item in page['Parameters']:
                item['Value'] = cls._parse_value(item['Value'], item['Type'])
                items[item['Name']] = item

        return items


    def refresh_on_error(
            self,
            error_class=Exception,
            error_callback=None,
            retry_argument='is_retry'
        ):
        """ Decorator to handle errors and retries """
        if error_callback and not callable(error_callback):
            raise TypeError("error_callback must be callable")
        def true_decorator(func):
            """ Actual func wrapper """
            @wraps(func)
            def wrapped(*args, **kwargs):
                """ Actual error/retry handling """
                try:
                    return func(*args, **kwargs)
                except error_class:  # pylint: disable=broad-except
                    self.refresh()
                    if error_callback:
                        error_callback()
                    if retry_argument:
                        kwargs[retry_argument] = True
                    return func(*args, **kwargs)
            return wrapped
        return true_decorator

class SSMParameterGroup(Refreshable):
    """ Concrete class that wraps multiple SSM Parameters """

    def __init__(self, max_age=None, with_decryption=True, base_path=""):
        super(SSMParameterGroup, self).__init__(max_age)

        self._with_decryption = with_decryption
        self._parameters = {}
        self._base_path = base_path or ""
        self._validate_path(base_path)  # may raise

    @staticmethod
    def _validate_path(path):
        if path and not path.startswith("/"):
            raise InvalidPathError("Invalid path: %s (should start with a slash)" % path)

    def parameter(self, path, add_prefix=True):
        """ Create a new SSMParameter by name/path (or retrieve an existing one) """
        if path in self._parameters:
            return self._parameters[path]
        if self._base_path and add_prefix:
            # validate path only if base path is used (otherwise it's just a root name)
            self._validate_path(path)  # may raise
            path = "%s%s" % (self._base_path, path)
        parameter = SSMParameter(path)
        parameter._group = self  # pylint: disable=protected-access
        self._parameters[path] = parameter
        return parameter

    def parameters(self, path, recursive=True, filters=None):
        """ Create new SSMParameter objects by path prefix """
        self._validate_path(path)  # may raise
        if self._base_path:
            path = "%s%s" % (self._base_path, path)
        items = self._get_parameters_by_path(
            with_decryption=self._with_decryption,
            path=path,
            recursive=recursive,
            filters=filters,
        )

        # keep track of update date for max_age checks
        # if a previous call to `parameters` was made, keep that time reference for caching
        self._update_refresh_time(keep_oldest_value=True)

        parameters = []
        # create new parameters and set values
        for name, item in six.iteritems(items):
            parameter = self.parameter(name, add_prefix=False)
            parameter._value = item['Value']  # pylint: disable=protected-access
            parameter._version = item['Version']  # pylint: disable=protected-access
            parameters.append(parameter)
        return parameters

    def secret(self, name):
        """ Create a new SecretsManagerParameter by name (or retrieve an existing one) """
        if name in self._parameters:
            return self._parameters[name]
        parameter = SecretsManagerParameter(name)
        parameter._group = self  # pylint: disable=protected-access
        self._parameters[name] = parameter
        return parameter

    def _refresh(self):
        # pylint: disable=protected-access
        names = [
            param.full_name
            for param in self.get_loaded_parameters()
        ]
        items, invalid_names = self._get_parameters(names, self._with_decryption)
        if invalid_names:
            raise InvalidParameterError(",".join(invalid_names))
        for parameter in self.get_loaded_parameters():
            if parameter.name not in items:
                raise InvalidParameterError(parameter.name)
            parameter._value = items[parameter.name]['Value']
            parameter._version = items[parameter.name]['Version']

    def get_loaded_parameters(self):
        """ Return a list of SSMParameter objects """
        return six.itervalues(self._parameters)

    def __len__(self):
        return len(self._parameters)

class SSMParameter(Refreshable):
    """ Concrete class for an individual SSM Parameter """

    def __init__(self, param_name, max_age=None, with_decryption=True):
        super(SSMParameter, self).__init__(max_age)
        if not param_name:
            raise ValueError("Must specify name")

        self._name, self._version, self._is_pinned_version = self._parse_version(param_name)

        self._value = None
        self._with_decryption = with_decryption
        self._group = None

    @staticmethod
    def _parse_version(param_name):
        """ Extracts version from full name, if provided """

        name, version, is_pinned_version = param_name, None, False

        if ":" in param_name:
            name, version = param_name.split(':')

            if version.isdigit() and int(version) > 0:
                version = int(version)
                is_pinned_version = True
            else:
                raise InvalidVersionError("Invalid version: %s" % version)

        return name, version, is_pinned_version

    def _should_refresh(self):
        if self._group:
            return self._group._should_refresh()  # pylint: disable=protected-access
        return super(SSMParameter, self)._should_refresh()

    def _refresh(self):
        """ Force refresh of the configured param names """
        if self._group:
            self._group.refresh()

        items, invalid_parameters = self._get_parameters([self.full_name], self._with_decryption)
        if invalid_parameters or self._name not in items:
            raise InvalidParameterError("%s is invalid. %s - %s" % (self._name, invalid_parameters, items))
        self._value = items[self._name]['Value']
        self._version = items[self._name]['Version']

    @property
    def name(self):
        """ Just an alias """
        return self._name

    @property
    def full_name(self):
        """ name + version """
        if self._version and self._is_pinned_version:
            return "%s:%s" % (self._name, self._version)
        return self._name

    @property
    def version(self):
        """ Just an alias """
        if self._version is None or self._should_refresh():
            self.refresh()
        return self._version

    @property
    def value(self):
        """ The value of a given param name. """
        if self._value is None or self._should_refresh():
            self.refresh()
        return self._value

class SecretsManagerParameter(SSMParameter):
    """ Concrete class for an individual Secrets Manager parameter """

    PREFIX = "/aws/reference/secretsmanager/"

    def __init__(self, param_name, max_age=None, with_decryption=True):
        param_name = self._add_prefix(param_name)
        super(SecretsManagerParameter, self).__init__(param_name, max_age, with_decryption)

    @classmethod
    def _add_prefix(cls, param_name):
        if not param_name:
            raise ValueError("Secret name can't be empty")
        if not param_name.startswith(cls.PREFIX):
            if param_name.startswith('/'):
                raise InvalidParameterError(param_name)
            param_name = "%s%s" % (cls.PREFIX, param_name)
        return param_name

def _batch(iterable, num):
    """Turn an iterable into an iterable of batches of size n (or less, for the last one)"""
    length = len(iterable)
    for ndx in range(0, length, num):
        yield iterable[ndx:min(ndx + num, length)]