#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Convenience utility functions for whoville, not really intended for external use
"""

from __future__ import absolute_import, unicode_literals
import logging
import json
import re
import time
import copy
import base64
import six
from six.moves import reduce
from time import sleep
from datetime import datetime, timedelta
import os
import ruamel.yaml
import requests
from github import Github
from github.GithubException import UnknownObjectException
from requests.models import Response
from whoville import config, security
from pexpect import pxssh
from pexpect.exceptions import EOF
from pexpect.pxssh import ExceptionPxssh

__all__ = ['dump', 'load', 'fs_read', 'fs_write', 'wait_to_complete',
           'check_remote_success_file', 'is_endpoint_up', 'set_endpoint',
           'get_val', 'set_val', 'deep_merge', 'get_remote_shell',
           'execute_remote_cmd', 'load_resources_from_files',
           'load_resources_from_github', 'validate_profile', 'resolve_tags',
           'Horton'
           ]

log = logging.getLogger(__name__)
# log.setLevel(logging.DEBUG)


def dump(obj, mode='json'):
    """
    Dumps a native datatype object to json or yaml, defaults to json

    Args:
        obj (varies): The native datatype object to serialise
        mode (str): 'json' or 'yaml', the supported export modes

    Returns (str): The serialised object

    """
    assert mode in ['json', 'yaml']
    out = json.dumps(
        obj=obj,
        sort_keys=True,
        indent=4
    )
    if mode == 'yaml':
        return ruamel.yaml.safe_dump(
            json.loads(out),
            default_flow_style=False
        )
    return out


def load(obj, dto=None, decode=None):
    """
    Loads a serialised object back into native datatypes, and optionally
    imports it back into a supported client DTO (currently Cloudbreak only)

    Warning: Using this on objects not produced by this Package may have
    unintended results! While efforts have been made to ensure that unsafe
    loading is not possible, no stringent security testing has been completed.

    Args:
        obj (str, bytes): The serialised object to import
        dto (Optional [tuple(str, str)]): A Tuple describing the service and
            object that should be constructed,
            e.g. dto = ('cloudbreak', 'StackResponse')
        decode (Optional [str]): Set to 'base64' if the object should be
            base64 decoded before loading
    Returns: Either the loaded object in native Python datatypes, or the
        constructed native datatype object

    """
    assert isinstance(obj, (six.string_types, bytes))
    assert dto is None or isinstance(dto, tuple)
    assert decode is None or isinstance(decode, six.string_types)
    # ensure object is standard json before reusing the api_client deserializer
    # safe_load from ruamel.yaml as it doesn't accidentally convert str
    # to unicode in py2. It also manages both json and yaml equally well
    # Good explanation: https://stackoverflow.com/a/16373377/4717963
    # Safe Load also helps prevent code injection
    if decode:
        if decode == 'base64':
            prep_obj = base64.b64decode(obj)
        else:
            raise ValueError("Load's decode option only supports base64")
    else:
        prep_obj = obj
    loaded_obj = ruamel.yaml.safe_load(prep_obj)
    if dto:
        assert dto[0] in ['cloudbreak']
        assert isinstance(dto[1], six.string_types)
        obj_as_json = dump(loaded_obj)
        response = Response()
        response.data = obj_as_json
        api_clients = {
            'cloudbreak': config.cb_config.api_client,
        }
        api_client = api_clients[dto[0]]
        return api_client.deserialize(
            response=response,
            response_type=dto[1]
        )
    return loaded_obj
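
# Illustrative usage (a sketch; round-trips with dump above):
#   >>> load('{"a": 1}')
#   {'a': 1}
#   >>> load(base64.b64encode(b'{"a": 1}'), decode='base64')
#   {'a': 1}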


def fs_write(obj, file_path):
    """
    Convenience function to write an Object to a FilePath

    Args:
        obj (varies): The Object to write out
        file_path (str): The Full path including filename to write to

    Returns: The object that was written
    """
    with open(str(file_path), 'w') as f:
        f.write(obj)
    return obj


def fs_read(file_path):
    """
    Convenience function to read an Object from a FilePath

    Args:
        file_path (str): The Full path including filename to read from

    Returns: The object that was read
    """
    try:
        with open(str(file_path), 'r') as f:
            return f.read()
    except UnicodeDecodeError:
        # fall back for files which are not valid utf-8
        with open(str(file_path), 'r', encoding='latin-1') as f:
            return f.read()
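
# Illustrative usage (a sketch; the path is arbitrary):
#   >>> fs_write('some text', '/tmp/example.txt')
#   'some text'
#   >>> fs_read('/tmp/example.txt')
#   'some text'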


def wait_to_complete(test_function, *args, **kwargs):
    """
    Implements a basic retry loop for a given test function, returning as
    soon as the function output evaluates to True

    Args:
        test_function: Function which returns a truthy output once the
            target state is reached
        whoville_delay (int): The number of seconds between each attempt,
            defaults to config.short_retry_delay
        whoville_max_wait (int): The maximum number of seconds before
            timing out, defaults to config.short_max_wait
        *args: Any args to pass through to the test function
        **kwargs: Any Keyword Args to pass through to the test function

    Returns: The first truthy output of the test function

    Raises: ValueError if max_wait is exceeded

    """
    log.info("Called wait_to_complete for function %s",
             test_function.__name__)
    delay = kwargs.pop('whoville_delay', config.short_retry_delay)
    max_wait = kwargs.pop('whoville_max_wait', config.short_max_wait)
    timeout = time.time() + max_wait
    while time.time() < timeout:
        log.debug("Calling test_function")
        test_result = test_function(*args, **kwargs)
        log.debug("Checking result")
        if test_result:
            log.debug("Function output [%s] eval to True, returning output",
                     str(test_result)[:25])
            return test_result
        log.debug("Function output [%s] evaluated to False, sleeping...",
                 str(test_result)[:25])
        time.sleep(delay)
    log.debug("Hit Timeout, raising TimeOut Error")
    raise ValueError("Timed Out waiting for {0} to complete".format(
        test_function.__name__))
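
# Illustrative usage (a sketch; polls the endpoint check defined below):
#   wait_to_complete(
#       is_endpoint_up, 'https://example.com',
#       whoville_delay=5, whoville_max_wait=120
#   )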


def is_endpoint_up(endpoint_url, verify=False):
    """
    Tests if a URL is available for requests

    Args:
        endpoint_url (str): The URL to test
        verify (bool): Whether to attempt SSL verification, if SSL needed

    Returns (bool): True for a 200 response, False for not

    """
    log.info("Called is_endpoint_up with args %s", locals())
    try:
        response = requests.get(endpoint_url, verify=verify)
        if response.status_code == 200:
            log.info("Got 200 response from endpoint, returning True")
            return True
        log.info("Got status code %s from endpoint, returning False",
                 response.status_code)
        return False
    except requests.ConnectionError:
        log.info("Got ConnectionError, returning False")
        return False


def get_remote_shell(target_host, sshkey_file=None, user_name=None, wait=True):
    """
    Gets a pxssh session to the target host, reusing a cached session where
    one is available and still alive

    Args:
        target_host (str): Hostname or IP to connect to
        sshkey_file (str): Path to the ssh private key, defaults to the
            Profile's sshkey_file
        user_name (str): The ssh login user, defaults to 'centos'
        wait (bool): Whether to retry until the host accepts the connection

    Returns: The pxssh session, or False if wait is unset and the host is
        not accepting connections
    """
    log.info("Getting remote shell for target host [%s]", target_host)
    horton = Horton()
    log.debug("Checking cache for existing Shell session to host")
    shell = horton.shells.get(target_host)
    if shell and not shell.isalive():
        log.debug("Cached shell is not live, recreating")
        shell = None
    if not shell:
        log.debug("Creating new session")
        sshkey_file = sshkey_file if sshkey_file else config.profile['sshkey_file']
        user_name = user_name if user_name else 'centos'
        while not shell:
            try:
                shell = pxssh.pxssh(options={"StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null"})
                shell.login(target_host, user_name, ssh_key=sshkey_file, check_local_ip=False)
            except (ExceptionPxssh, EOF):
                # discard the failed session so the loop retries cleanly
                shell = None
                if not wait:
                    log.info("Target host is not accepting the connection, Wait is not set, returning False...")
                    return False
                log.info("Retrying until target host accepts the connection request...")
                sleep(5)
        horton.shells[target_host] = shell
        log.info("Returning Shell session...")
    return shell
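
# Illustrative usage (a sketch; the host and key path are hypothetical):
#   shell = get_remote_shell(
#       '10.0.0.5', sshkey_file='/tmp/mykey.pem', user_name='centos',
#       wait=False)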


def execute_remote_cmd(target_host, cmd, expect=None, repeat=False, bool_response=False):
    """
    Executes a command over a remote shell on the target host, optionally
    waiting for an expected string in the output

    Args:
        target_host (str): Hostname or IP to connect to
        cmd (str): The shell command to execute
        expect (str): Optional string to wait for in the command output
        repeat (bool): Whether to reissue the command while waiting on expect
        bool_response (bool): If set, return False rather than wait when the
            shell is unavailable or expect is not found

    Returns (str, bool): The decoded command output, or a bool if
        bool_response is set
    """
    log.info("Executing remote command [%s] on host [%s] expecting output of [%s] with wait-repeat of [%s] and "
             "bool_response of [%s]", cmd[:100], target_host, str(expect), str(repeat), str(bool_response))
    assert isinstance(cmd, six.string_types)
    assert expect is None or isinstance(expect, six.string_types)
    assert isinstance(repeat, bool)
    assert isinstance(bool_response, bool)
    if bool_response and not expect:
        raise ValueError("Must include an Expect statement with bool_response test")
    s = get_remote_shell(target_host, wait=not bool_response)
    if not s:
        if bool_response:
            log.info("Remote Shell not currently available, bool_respose is set, returning False")
            return False
        else:
            raise ValueError("Remote Shell not available to host [{0}]".format(target_host))
    log.debug("Issuing command [%s]", cmd)
    s.sendline(cmd)
    s.prompt()
    if not expect:
        log.info("Expect not set, returning command result...")
        return s.before.decode()
    while expect not in s.before.decode():
        if bool_response:
            return False
        log.info("Expect set and string not found in response, waiting...")
        sleep(3)
        s.prompt()
        if repeat:
            log.info("Repeat set, reissuing command before checking again")
            s.sendline(cmd)
    log.info("Expect set and found in command response, returning response")
    return s.before.decode()
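
# Illustrative usage (a sketch; the host and service name are hypothetical):
#   execute_remote_cmd(
#       '10.0.0.5', 'systemctl is-active mysqld',
#       expect='active', repeat=True)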


def check_remote_success_file(target_host, check_file='/tmp/status.success'):
    """
    Checks that the remote success file exists and reports 'complete'
    """
    response = execute_remote_cmd(target_host, 'cat ' + check_file)
    if 'complete' in response:
        log.info("Found complete in success file, ready to proceed")
        return True
    log.info("Could not find complete in success file")
    return False


def set_endpoint(endpoint_url):
    """
    EXPERIMENTAL

    Sets the endpoint when switching between instances of NiFi or other
    projects. Not tested extensively with secured instances.

    Args:
        endpoint_url (str): The URL to set as the endpoint. Autodetects the
        relevant service e.g. 'http://localhost:18080/nifi-registry-api'

    Returns (bool): True for success, False for not
    """
    log.info("Called set_endpoint with args %s", locals())
    if 'cb/api' in endpoint_url:
        log.debug("Setting Cloudbreak endpoint to %s", endpoint_url)
        this_config = config.cb_config
    elif ':7189' in endpoint_url:
        log.debug("Setting Altus Director endpoint to %s", endpoint_url)
        this_config = config.cd_config
    else:
        raise ValueError("Unrecognised API Endpoint")
    try:
        if this_config.api_client:
            log.debug("Found Active API Client, updating...")
            this_config.api_client.host = endpoint_url
    except AttributeError:
        log.debug("No Active API Client found to update")
    this_config.host = endpoint_url
    if this_config.host == endpoint_url:
        return True
    return False
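
# Illustrative usage (a sketch; the hostname is hypothetical):
#   set_endpoint('https://cloudbreak.example.com/cb/api')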


# https://stackoverflow.com/a/36584863/4717963
# https://stackoverflow.com/a/14692747/4717963
def get_val(root, items, sep='.', **kwargs):
    """
    Swagger client objects don't behave like dicts, so need a custom func
    to step down through keys when defined as string vars etc.

    Warnings:
        If you try to retrieve a key that doesn't exist you will get None
        instead of an Attribute Error. Code defensively, or abuse it, whatever.

    Args:
        root (dict, object): The dict or Object to recurse through
        items (list, str): either list or dot notation string of keys to walk
            through
        sep (str): The character expected as a separator when parsing strings

    Returns (varies): The target val at the last key

    """
    assert isinstance(items, (list, six.string_types))
    for key in items if isinstance(items, list) else items.split(sep):
        if root is None:
            return root
        elif isinstance(root, list):
            if '|' not in key:
                raise ValueError("Found list but key {0} does not match list "
                                 "filter format 'x|y'".format(key))
            field, value = key.split('|')
            list_filter = [x for x in root if x.get(field) == value]
            if list_filter:
                root = list_filter[0]
            else:
                return None
        elif isinstance(root, dict):
            root = root.get(key)
        else:
            root = root.__getattribute__(key)
    return root
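
# Illustrative usage (a sketch), including the 'field|value' list filter:
#   >>> d = {'spec': {'nodes': [{'name': 'master', 'count': 1}]}}
#   >>> get_val(d, 'spec.nodes.name|master.count')
#   1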


def set_val(root, keys, val, sep='.', merge=False, ignore_keys=None,
            squash_keys=None, max_depth=50):
    """
    Counterpart to get_val; walks to the penultimate key and sets or merges
    the value at the last key

    Args:
        root (dict, object): The dict or Object to recurse through
        keys (list, str): list or dot notation string of keys to walk through
        val (varies): The value to set, or merge in if merge is True
        sep (str): Separator to use when parsing keys from a string
        merge (bool): Whether to deep_merge val into the existing value
        ignore_keys (list): Keys to skip during a merge
        squash_keys (list): Keys to overwrite rather than merge
        max_depth (int): Maximum recursion depth for the merge
    """
    assert isinstance(keys, (list, six.string_types))
    if isinstance(keys, six.string_types):
        log.debug("got keys as string, splitting using sep [%s]", sep)
        keys = keys.split(sep)
    log.debug("keys are [%s]", str(keys))
    last_key = keys.pop()
    log.debug("grabbing last key [%s] off the end", last_key)
    root = get_val(root, keys, sep)
    log.debug("Got root from keys [%s]", str(keys))
    if not merge:
        log.debug("not merge update, last key is [%s], replacing", last_key)
        root[last_key] = copy.deepcopy(val)
    else:
        log.debug("running merge update on root like [%s] with value like [%s]",
                  str(root)[:100], str(val)[:100])
        merged = deep_merge(
            target=root[last_key], source=copy.deepcopy(val),
            ignore_keys=ignore_keys, squash_keys=squash_keys,
            max_depth=max_depth)
        log.debug("replacing original root at [%s] with merged root like [%s]",
                  last_key, str(merged)[:100])
        root[last_key] = merged
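
# Illustrative usage (a sketch):
#   >>> d = {'spec': {'count': 1}}
#   >>> set_val(d, 'spec.count', 3)
#   >>> d
#   {'spec': {'count': 3}}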


# https://stackoverflow.com/a/18394648/4717963
# This cannot handle nested lists
def deep_merge(target, source, ignore_keys=None, squash_keys=None, depth=0,
               max_depth=50):
    """
    Recursively merges source into target: nested dicts are merged, lists
    are concatenated, and simple values are overwritten
    """
    for k, v in source.items():
        if ignore_keys and k in ignore_keys:
            log.debug("k [%s] is on ignore list, skipping", k)
        elif depth == max_depth:
            log.debug("hit max merge depth, squashing k [%s] with new v like "
                      "[%s]", k, str(v)[:100])
            target[k] = v
        elif squash_keys and k in squash_keys:
            log.debug("squashing k [%s] with new v like [%s]", k, str(v)[:100])
            target[k] = v
        elif isinstance(v, dict) and v:
            log.debug("Running recursive update on k [%s]", k)
            target[k] = deep_merge(
                target=target.get(k, {}), source=v, ignore_keys=ignore_keys,
                squash_keys=squash_keys, depth=depth+1, max_depth=max_depth)
        elif isinstance(v, list):
            log.debug("Merging list under k [%s]", k)
            target[k] = (target.get(k, []) + v)
        else:
            log.debug("simple value, updating [%s] with value like [%s]",
                      k, str(v)[:100])
            target[k] = v
    log.debug("Returning merged object")
    return target
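
# Illustrative usage (a sketch): dicts merge, lists concatenate:
#   >>> deep_merge({'a': {'x': 1}, 'b': [1]}, {'a': {'y': 2}, 'b': [2]})
#   {'a': {'x': 1, 'y': 2}, 'b': [1, 2]}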


def load_resources_from_github(repo_name, username, token, tgt_dir,
                               ref='master', recurse=True):
    """
    Loads resource files from a Github repo directory into a nested dict,
    parsing yaml/json files and keeping other files as raw strings
    """
    def _recurse_github_dir(g_repo, r_tgt, r_ref):
        contents = g_repo.get_dir_contents(r_tgt, r_ref)
        out = {}
        for obj in contents:
            log.info("loading " + os.sep.join([r_tgt, r_ref, obj.name]))
            if obj.type == 'dir':
                out[obj.name] = _recurse_github_dir(g_repo, obj.path, r_ref)
            elif obj.type == 'file':
                if obj.name.rsplit('.', 1)[-1] not in ['yaml', 'json']:
                    out[obj.name] = obj.decoded_content.decode('utf-8')
                else:
                    out[obj.name] = load(obj.decoded_content)
        return out

    try:
        g_accnt = Github(username, token)
    except UnknownObjectException:
        raise ValueError("Github Login failure - please check you have "
                         "access to Repo {0} and your token is correctly "
                         "set up".format(repo_name))
    g_repo = g_accnt.get_repo(repo_name)
    if not recurse:
        listing = g_repo.get_dir_contents(tgt_dir, ref)
        return listing
    return _recurse_github_dir(g_repo, tgt_dir, ref)
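
# Illustrative usage (a sketch; repo, credentials and path are hypothetical):
#   resources = load_resources_from_github(
#       'myorg/myrepo', 'myuser', 'mytoken', 'resources')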


def load_resources_from_files(file_path):
    """
    Loads resource files from a local directory tree into a nested dict,
    parsing yaml/json files and keeping other files as raw strings
    """
    resources = {}
    # http://code.activestate.com/recipes/577879-create-a-nested-dictionary-from-oswalk/
    rootdir = file_path.rstrip(os.sep)
    log.debug("Trying path {0}".format(rootdir))
    head = rootdir.rsplit(os.sep)[-1]
    start = rootdir.rfind(os.sep) + 1
    for path, dirs, files in os.walk(rootdir):
        log.debug("Trying path {0}".format(path))
        folders = path[start:].split(os.sep)
        subdir = dict.fromkeys(files)
        parent = reduce(dict.get, folders[:-1], resources)
        parent[folders[-1]] = subdir
        for file_name in subdir.keys():
            if file_name[0] == '.':
                log.debug("skipping dot file [%s]", file_name)
                continue
            log.debug("loading [%s]", os.path.join(path, file_name))
            if file_name.rsplit('.', 1)[-1] not in ['yaml', 'json']:
                subdir[file_name] = fs_read(os.path.join(path, file_name))
            else:
                subdir[file_name] = load(
                    fs_read(os.path.join(path, file_name))
                )
    return resources[head]
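
# Illustrative usage (a sketch; the path is hypothetical):
#   resources = load_resources_from_files('/path/to/resources')
#   # returns a nested dict mirroring the folder structure, with yaml/json
#   # files parsed and other files kept as raw strings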


def singleton(cls, *args, **kw):
    """Decorator ensuring only a single shared instance of the class exists"""
    instances = {}

    def _singleton():
        if cls not in instances:
            instances[cls] = cls(*args, **kw)
        return instances[cls]
    return _singleton


@singleton
class Horton:
    """
    Singleton to share state between the various processes.
    Looks complicated, but it makes the rest of the code more readable for
    Non-Python natives.
    ...
    Why Horton? Because an Elephant Never Forgets
    """
    def __init__(self):
        self.cbd = None  # Server details for orchestration host
        self.cbcred = None  # Credential for deployments, once loaded in CB
        self.cdcred = None  # Credential for deployments, once loaded in CD
        self.cad = None  # Client for Altus Director, once created
        self.k8svm = {}  # Reference for K8s environment, once created
        self.resources = {}  # all loaded resources from github/files
        self.defs = {}  # deployment definitions, once pulled from resources
        self.specs = {}  # stack specifications, once formulated
        self.stacks = {}  # stacks deployed, once submitted
        self.deps = {}  # Dependencies loaded for a given Definition
        self.seq = {}  # Prioritised list of tasks to execute
        self.cache = {}  # Key:Value store for passing params between Defs
        self.shells = {}  # Key:Value session store for remote shells
        self.namespace = config.profile['namespace']
        self.global_purge = config.profile.get('globalpurge', False)

    def __iter__(self):
        for attr, value in self.__dict__.items():
            yield attr, value

    def _getr(self, keys, sep=':', **kwargs):
        """
        Convenience function to retrieve params in a very readable method

        Args:
            keys (str): dot notation string of the key for the value to be
                retrieved. e.g 'secret.cloudbreak.hostname'

        Returns:
            The value if found, or None if not
        """
        return get_val(self, keys, sep, **kwargs)

    def _setr(self, keys, val, sep=':', **kwargs):
        set_val(self, keys, val, sep, **kwargs)
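
# Illustrative usage (a sketch; note the ':' separator and the shared state):
#   horton = Horton()
#   assert Horton() is horton  # singleton: every call returns the same object
#   horton._setr('cache:MYKEY', 'myvalue')
#   horton._getr('cache:MYKEY')  # -> 'myvalue'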


def validate_profile():
    log.info("Validating provided profile.yml")
    horton = Horton()
    # TODO: Check VPN if OpenStack
    # Check Profile is imported
    if not config.profile:
        raise ValueError("whoville Config Profile is not populated with "
                         "deployment controls, cannot proceed")
    # Check Profile version
    if 'profilever' not in config.profile:
        raise ValueError("Your Profile is out of date, please recreate your "
                         "Profile from the template")
    if config.profile['profilever'] < config.min_profile_ver:
        raise ValueError("Your Profile is out of date, please recreate your "
                         "Profile from the template. Profile v3 requires an ssh private key or pem file.")
    # Handle SSH
    if 'sshkey_file' in config.profile and config.profile['sshkey_file']:
        assert config.profile['sshkey_file'].endswith('.pem')
        from Crypto.PublicKey import RSA
        pem_key = RSA.importKey(fs_read(config.profile['sshkey_file']))
        config.profile['sshkey_pub'] = pem_key.publickey().exportKey(format="OpenSSH").decode()
        config.profile['sshkey_priv'] = pem_key.exportKey().decode()
        config.profile['sshkey_name'] = os.path.basename(config.profile['sshkey_file']).split('.')[0]
    else:
        assert any(k in config.profile for k in ['ssh_key_priv', 'sshkey_priv'])
        assert all(k in config.profile for k in ['sshkey_pub', 'sshkey_name'])
    # Check Namespace
    assert isinstance(horton.namespace, six.string_types),\
        "Namespace must be string"
    assert len(horton.namespace) >= 2,\
        "Namespace must be at least 2 characters"
    # Check Password
    if 'password' in config.profile and config.profile['password']:
        horton.cache['ADMINPASSWORD'] = config.profile['password']
    else:
        horton.cache['ADMINPASSWORD'] = security.get_secret('ADMINPASSWORD')
    password_test = re.compile(r'^(?=.*[A-Za-z])(?=.*\d)[A-Za-z\d-]{12,}$')
    if not bool(password_test.match(horton.cache['ADMINPASSWORD'])):
        raise ValueError("Password doesn't match Platform spec."
                         "Requires 12+ characters, at least 1 letter and "
                         "number, may also contain -")
    # Check Provider
    platform = config.profile.get('platform')
    assert platform['provider'] in ['EC2', 'AZURE_ARM', 'GCE', 'OPENSTACK']
    if platform['provider'] == 'GCE':
        if 'apikeypath' in platform:
            with open(platform['apikeypath'], "r") as apikey:
                platform['jsonkey'] = apikey.read()
    # TODO: Read in the profile template, check it has all matching keys
    # Check Profile Namespace is valid
    ns_test = re.compile(r'^[a-z0-9-]+$')
    if not bool(ns_test.match(horton.namespace)):
        raise ValueError("Namespace must only contain 0-9 a-z -")
    # Check storage bucket matches expected format
    if 'bucket' in config.profile:
        if platform['provider'] == 'EC2':
            bucket_test = re.compile(r'^[a-z0-9.-]+$')
        elif platform['provider'] == 'AZURE_ARM':
            bucket_test = re.compile(r'^[a-z0-9@]+$')
        elif platform['provider'] == 'GCE':
            bucket_test = re.compile(r'^[a-z0-9.-]+$')
        else:
            raise ValueError("bucket listed in Profile but Platform Provider not supported")
        if not bool(bucket_test.match(config.profile['bucket'])):
            raise ValueError("Bucket name doesn't match Platform spec")
    # check tags
    if 'tags' not in config.profile:
        raise ValueError("Profile is missing mandatory tags entries")
    tags = config.profile['tags']
    for tag in ['owner']:
        assert tag in tags, "tag {0} missing from profile".format(tag)
        assert isinstance(tags[tag], six.string_types) and \
            len(tags[tag]) > 3, \
            "Tag {0} must be a string over 3 chars".format(tag)


def resolve_tags(instance_name, owner):
    """
    Populates the deployment tags from the Profile with sensible defaults
    before submission
    """
    tags = config.profile.get('tags')
    if tags is not None:
        if 'owner' not in tags or tags['owner'] is None:
            tags['owner'] = owner
        if 'startdate' not in tags or tags['startdate'] is None:
            tags['startdate'] = str(datetime.now().strftime("%m%d%Y").lower())
        if 'enddate' not in tags or tags['enddate'] is None:
            tags['enddate'] = str(
                (datetime.now() + timedelta(days=2)).strftime("%m%d%Y").lower())
        if 'project' not in tags or tags['project'] is None:
            tags['project'] = 'selfdevelopment'
        if 'deploytool' not in tags or tags['deploytool'] is None:
            tags['deploytool'] = 'whoville'
        tags['dps'] = 'false'
        tags['datalake'] = 'false'
    else:
        tags = {'datalake': 'false', 'dps': 'false'}

    if 'dps' in instance_name:
        tags['dps'] = 'true'
    if 'datalake' in instance_name:
        tags['datalake'] = 'true'
    return tags
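
# Illustrative usage (a sketch; assumes config.profile carries a tags block):
#   tags = resolve_tags('mydef-datalake', 'me@example.com')
#   # tags['datalake'] == 'true'; owner/start/end dates are filled if unset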