# -*- coding: utf-8 -*-
#
# Copyright 2017, Rambler Digital Solutions
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import shlex
import subprocess
import sys
from itertools import chain

import jinja2
import yaml
from jinja2.ext import Extension
from jinja2.lexer import Token
from trafaret import str_types

from .compat import Iterable, Mapping
from .schema import dump, ensure_schema


try:
    from functools import reduce
except ImportError:
    pass


def yaml_filter(obj):
    """Jinja2 filter that turns a rendered value into YAML source text.

    * Strings are returned unchanged, so templates can emit raw YAML.
    * A ``jinja2.Undefined`` value (e.g. a missing attribute) renders as "".
    * Anything else is passed to ``dump`` from ``.schema`` (presumably a YAML
      serializer -- confirm there).

    Raises:
        RuntimeError: if ``dump`` fails for ``obj``; the message names the
            offending object and the underlying error.
    """
    if isinstance(obj, str_types):
        return obj
    if isinstance(obj, jinja2.Undefined):
        return ""
    try:
        return dump(obj)
    except Exception as exc:
        # Re-raised as RuntimeError with a user-facing hint. On Python 3 the
        # original exception remains attached as implicit context.
        raise RuntimeError(
            "Unable to serialize {!r} to YAML because {}. "
            "Template render must produce valid YAML file, so please use"
            " simple types in `with_items` block."
            "".format(obj, exc)
        )


class YamlExtension(Extension):
    """Jinja2 extension that pipes every ``{{ ... }}`` expression through the
    ``yaml`` filter, so each substitution is emitted as YAML text.
    """

    def filter_stream(self, stream):
        """
        We convert
        {{ some.variable | filter1 | filter 2}}
            to
        {{ (some.variable | filter1 | filter 2) | yaml}}

        ... for all variable declarations in the template. Expressions that
        already end with an explicit ``| yaml`` filter are emitted unchanged,
        so the filter is never applied twice.

        This function is called by jinja2 immediately
        after the lexing stage, but before the parser is called.
        """
        while not stream.eos:
            token = next(stream)
            if not token.test("variable_begin"):
                yield token
                continue

            # Collect `variable_begin` and every token up to (but excluding)
            # the matching `variable_end`.
            var_expr = []
            while not token.test("variable_end"):
                var_expr.append(token)
                token = next(stream)
            variable_end = token

            # Only a trailing `| yaml` filter counts. A bare name such as
            # `{{ item.yaml }}` is an ordinary expression and still needs it.
            already_yaml = (
                len(var_expr) > 2
                and var_expr[-2].test("pipe")
                and var_expr[-1].test("name")
                and var_expr[-1].value == "yaml"
            )
            if not already_yaml:
                # Wrap the whole expression between the `variable_begin`
                # and `variable_end` marks in parens, so the `yaml` filter
                # applies to all of it:
                var_expr.insert(1, Token(var_expr[0].lineno, "lparen", None))
                var_expr.append(Token(var_expr[-1].lineno, "rparen", None))

                var_expr.append(Token(variable_end.lineno, "pipe", "|"))
                var_expr.append(Token(variable_end.lineno, "name", "yaml"))

            # Always re-emit the collected tokens. They were consumed from
            # `stream` above and would otherwise be lost from the template.
            for expr_token in var_expr:
                yield expr_token
            yield variable_end


# Shared Jinja2 environment used to render templated schema strings.
# YamlExtension rewrites every `{{ ... }}` to end in `| yaml`, and the
# `yaml` filter registered below turns each value into YAML text.
ENV = jinja2.Environment(extensions=[YamlExtension])
ENV.filters["yaml"] = yaml_filter


def transform(schema):
    """Run the full transformation pipeline over ``schema``.

    Steps, in order: ``ensure_schema`` (from ``.schema``), then
    ``transform_templates`` (expands each DAG's ``do`` templates), then
    ``transform_defaults`` (applies each DAG's ``defaults``).
    Returns the resulting schema.
    """
    return transform_defaults(transform_templates(ensure_schema(schema)))


def transform_templates(schema):
    """Expand the ``do`` templates of every DAG in ``schema``, in place.

    For each DAG schema that has a ``do`` key, the key is removed and every
    template in its list is applied via ``transform_strategy``. DAGs without
    ``do`` are left alone. Returns the same ``schema`` object.
    """
    for dag_schema in schema["dags"].values():
        if "do" in dag_schema:
            for template in dag_schema.pop("do"):
                transform_strategy(dag_schema, template)
    return schema


def transform_strategy(schema, template):
    """Apply one ``do`` template to a DAG ``schema``.

    Only ``with_items`` templates are supported; they are delegated to
    ``transform_with_items``.

    Raises:
        RuntimeError: if ``template`` has no ``with_items`` key.
    """
    if "with_items" not in template:
        raise RuntimeError("cannot figure how to apply template: {}".format(template))
    return transform_with_items(schema, template)


def transform_with_items(schema, template):
    """Expand a ``with_items`` template into the DAG ``schema``, in place.

    ``template["with_items"]`` may be:

    * a mapping ``{"using": value}`` -- ``value`` is used as the items;
    * a mapping ``{"from_stdout": cmd}`` -- ``cmd`` is executed and its
      stdout parsed as JSON (see ``from_stdout``);
    * a callable -- it is called with no arguments to obtain the items;
    * any other iterable of items.

    For each of the ``operators`` / ``sensors`` / ``flow`` sections present
    in the template, the section is rendered once per item, and the results
    are merged into ``schema[section]`` via ``merge``.

    Returns:
        The same ``schema`` object.

    Raises:
        RuntimeError: if the resolved items are not iterable.
    """
    items = template["with_items"]
    if isinstance(items, Mapping):
        if set(items) == {"using"}:
            items = items["using"]
        elif set(items) == {"from_stdout"}:
            items = from_stdout(items["from_stdout"])
    if callable(items):
        items = items()
    if not isinstance(items, Iterable):
        raise RuntimeError("bad with_items template: {}".format(items))
    # Materialize once. The same items are reused for every section below,
    # and a one-shot iterator (e.g. a generator returned by a callable) would
    # otherwise be exhausted after the first section.
    items = list(items)
    for key in ("operators", "sensors", "flow"):
        if key not in template:
            continue
        subschema = reduce(merge, transform_schema_with_items(template[key], items), {})
        schema[key] = merge(schema.get(key, {}), subschema)
    return schema


def from_stdout(cmd):
    """Run ``cmd`` and return its standard output parsed as JSON.

    ``cmd`` is split into an argument list with ``shlex.split`` and executed
    directly, without a shell, via ``subprocess.check_output``. A non-zero
    exit status therefore raises ``subprocess.CalledProcessError``. Output
    is treated as UTF-8.
    """
    on_python2 = sys.version_info[0] == 2
    # On Python 2 the command is split as UTF-8 bytes instead of unicode.
    argv = shlex.split(cmd.encode("utf8") if on_python2 else cmd)
    raw_output = subprocess.check_output(argv)
    # Python 3 returns bytes here; decode explicitly before parsing.
    text = raw_output if on_python2 else raw_output.decode("utf8")
    return json.loads(text)


def transform_schema_with_items(schema, items):
    """Render the mapping ``schema`` once per element of ``items``.

    Returns a list with one ``transform_dict_with_item`` result per item,
    in iteration order.
    """
    return [transform_dict_with_item(schema, each_item) for each_item in items]


def transform_value_with_item(value, item):
    """Recursively render templates inside ``value`` using ``item``.

    Mappings and non-string iterables are walked recursively. Strings are
    rendered as Jinja2 templates. Any other value is returned unchanged.
    """
    if isinstance(value, Mapping):
        return transform_dict_with_item(value, item)
    if isinstance(value, str_types):
        # Must be checked before the generic Iterable case: strings are
        # iterable too, but they are templates, not sequences to walk.
        return transform_string_with_item(value, item)
    if isinstance(value, Iterable):
        return transform_list_with_item(value, item)
    return value


def transform_dict_with_item(dict, item):
    """Return a new dict with both keys and values of ``dict`` rendered
    against ``item`` via ``transform_value_with_item``.

    Note: the parameter name shadows the ``dict`` builtin. It is kept as-is
    for backward compatibility with keyword callers.
    """
    return {
        transform_value_with_item(key, item): transform_value_with_item(val, item)
        for key, val in dict.items()
    }


def transform_list_with_item(list, item):
    """Return a new list with every element of the iterable ``list``
    rendered against ``item`` via ``transform_value_with_item``.

    Any iterable input (tuple, generator, ...) produces a ``list``.
    The parameter name shadows the builtin and is kept for compatibility.
    """
    return [transform_value_with_item(element, item) for element in list]


def transform_string_with_item(string, item, env=ENV):
    """Render ``string`` as a Jinja2 template with ``item`` in scope, then
    parse the rendered text with ``yaml.safe_load``.

    Parsing with ``safe_load`` rather than returning raw text is
    deliberate. It keeps results limited to plain YAML data, so templates
    cannot inject arbitrary Python objects into the schema.
    """
    template = env.from_string(string)
    rendered = template.render(item=item)
    return yaml.safe_load(rendered)


def merge(base, other):
    """Combine two schema values, with ``base`` taking precedence.

    * two mappings  -> ``merge_mappings`` (recursive key-wise merge);
    * two strings   -> ``base``;
    * two other iterables -> ``merge_iterable`` (concatenated list);
    * anything else -> ``base``.
    """
    if isinstance(base, Mapping) and isinstance(other, Mapping):
        return merge_mappings(base, other)
    if isinstance(base, str_types) and isinstance(other, str_types):
        # Strings are iterable, but concatenating their characters would be
        # wrong, so the first value wins.
        return base
    if isinstance(base, Iterable) and isinstance(other, Iterable):
        return merge_iterable(base, other)
    # Scalars or mismatched kinds: keep ``base``.
    return base


def merge_mappings(base, other):
    """Return a new dict that merges mapping ``other`` into mapping ``base``.

    Keys only in ``other`` are added. Keys in both with equal values are
    kept as-is. Keys in both with differing values are combined with
    ``merge(base[key], other[key])``, so ``base`` wins for scalars and
    strings. Neither input is mutated.
    """
    # `dict(base)`, not `dict(**base)`: the keyword-unpacking form raises
    # TypeError on Python 3 as soon as any key is not a string.
    result = dict(base)
    for key in other:
        if key not in base:
            result[key] = other[key]
        elif base[key] == other[key]:
            continue
        else:
            result[key] = merge(base[key], other[key])
    return result


def merge_iterable(base, other):
    """Return a new list with the items of ``base`` followed by ``other``."""
    combined = list(base)
    combined.extend(other)
    return combined


def transform_defaults(schema):
    """Apply each DAG's ``defaults`` to its tasks, in place.

    The ``defaults`` key is always removed from every DAG schema. When it
    was non-empty, its ``sensors`` / ``operators`` entries are applied to
    the DAG's sections of the same name (only if both sides define them).
    Returns the same ``schema`` object.
    """
    for dag_schema in schema["dags"].values():
        defaults = dag_schema.pop("defaults", {})
        if not defaults:
            continue
        for section in ("sensors", "operators"):
            if section in dag_schema and section in defaults:
                transform_apply_tasks_defaults(dag_schema[section], defaults[section])
    return schema


def transform_apply_tasks_defaults(tasks, defaults):
    """Apply ``defaults`` to every task in the ``tasks`` mapping, in place.

    Each value is replaced by
    ``transform_apply_task_defaults(task, defaults)``. Returns None.
    """
    for name in list(tasks):
        tasks[name] = transform_apply_task_defaults(tasks[name], defaults)


def transform_apply_task_defaults(task, defaults):
    """Return a new task schema with ``defaults`` merged into ``task``.

    Delegates to ``merge_mappings(task, defaults)``. Values already set on
    the task take precedence, and keys missing from the task are filled
    in from ``defaults``.
    """
    merged = merge_mappings(task, defaults)
    return merged