import html
import importlib
import re
import token
import tokenize

from io import StringIO
from pathlib import Path
from collections import defaultdict, deque, namedtuple, ChainMap
from typing import Dict, Iterable

# Token types yielded by tokenise(). The three tag types equal the named groups
# of tag_re, so tokenise() uses ``match.lastgroup`` directly as the token type.
TOK_COMMENT = "comment"
TOK_TEXT = "text"
TOK_VAR = "var"
TOK_BLOCK = "block"

# One template tag: {% block tag %}, {{ variable }} or {# comment #}.
# Bodies are matched non-greedily (ending at the first closing delimiter) and
# DOTALL lets a tag span several lines; the \s* runs keep the whitespace just
# inside the delimiters out of the captured group.
tag_re = re.compile(r"{%\s*(?P<block>.+?)\s*%}|{{\s*(?P<var>.+?)\s*}}|{#\s*(?P<comment>.+?)\s*#}", re.DOTALL)
# A lexed piece of template: ``type`` is one of the TOK_* constants, ``content`` its text.
Token = namedtuple("Token", "type content")

class SafeStr(str):
    """A string already deemed safe for output, so it is written without escaping.

    The ``__safe__`` marker is what variable rendering checks for; ``__str__``
    hands back the instance itself so the marker survives a ``str()`` call.
    """

    __safe__ = True

    def __str__(self):
        # Return this very object rather than a plain ``str`` copy.
        return self

def tokenise(template):
    """Split *template* into a stream of Token tuples.

    Text between tags becomes TOK_TEXT tokens (verbatim); each tag becomes a
    token whose type is the name of the tag_re group that matched and whose
    content is the stripped tag body.
    """
    pos = 0
    for match in tag_re.finditer(template):
        begin = match.start()
        if begin > pos:
            yield Token(TOK_TEXT, template[pos:begin])
        pos = match.end()
        kind = match.lastgroup
        yield Token(kind, match.group(kind).strip())
    # Trailing text after the final tag, if any.
    if pos < len(template):
        yield Token(TOK_TEXT, template[pos:])

class TemplateLoader(dict):
    """Lazy, caching template store keyed by template name.

    ``loader[name]`` returns a cached Template; on a miss, ``__missing__``
    loads it from the first search path containing the file and caches it.
    """

    def __init__(self, paths):
        # Resolve every search root up front, keeping the order given.
        self.paths = [Path(entry).resolve() for entry in paths]

    def load(self, name, encoding="utf8"):
        """Read *name* from the first search path that has it and compile it.

        Raises:
            LookupError: no search path contains a file called *name*.
        """
        candidates = (root / name for root in self.paths)
        found = next((candidate for candidate in candidates if candidate.is_file()), None)
        if found is None:
            raise LookupError(name)
        return Template(found.read_text(encoding), loader=self, name=name)

    def __missing__(self, key):
        template = self.load(key)
        self[key] = template
        return template

class Context(ChainMap):
    """Layered variable scopes used while rendering.

    The positional maps are searched in order, then a final map supplying the
    names ``True``/``False``/``None`` so expressions can refer to them.
    ``escape`` is applied to variable output that is not marked safe.
    """

    def __init__(self, *args, escape=html.escape):
        # ChainMap.__init__ creates self.maps; it must run before appending.
        super().__init__(*args)
        self.maps.append({"True": True, "False": False, "None": None})
        self.escape = escape

    def push(self, data=None):
        """Add a new innermost scope (*data* or a fresh dict).

        Returns self so it can be used as ``with context.push(...):``; leaving
        the ``with`` block removes the scope again.
        """
        self.maps.insert(0, data or {})
        return self

    def new_child(self, m=None, **kwargs):
        """Return a new Context whose innermost scope is *m* (plus any *kwargs*).

        Overrides ChainMap.new_child so the child keeps this context's
        ``escape`` function and does not stack a second builtins map.
        """
        if m is None:
            m = kwargs
        elif kwargs:
            m.update(kwargs)
        # maps[-1] is the builtins map; the child's __init__ appends its own.
        return type(self)(m, *self.maps[:-1], escape=self.escape)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Undo the scope added by push(); returning None never suppresses exceptions.
        self.maps.pop(0)

class Nodelist(list):
    """An ordered list of parsed nodes.

    Template.parse_nodelist also attaches an ``endnode`` attribute recording
    the node that terminated the list.
    """

    def render(self, context, output):
        """Render each node in turn into *output*."""
        for child in self:
            child.render(context, output)

    def nodes_by_type(self, node_type):
        """Yield nodes of *node_type*, depth-first, descending into block nodes."""
        for child in self:
            if isinstance(child, node_type):
                yield child
            if not isinstance(child, BlockNode):
                continue
            yield from child.nodes_by_type(node_type)

class Template:
    """A template parsed once from source into a node tree, renderable many times."""

    def __init__(self, src, loader=None, name=None):
        """Tokenise and parse *src* into ``self.nodelist``.

        Args:
            src: template source text.
            loader: template loader needed by {% include %} / {% extends %}
                (e.g. a TemplateLoader); may be None for standalone templates.
            name: template name, kept so we can report where a fault was.
        """
        self.tokens, self.loader, self.name = tokenise(src), loader, name
        self.nodelist = self.parse_nodelist([])

    def parse(self):
        """Yield nodes built from the remaining tokens.

        Every generator shares ``self.tokens``, so a fresh ``parse()`` resumes
        where the previous one stopped; block tags consume their own bodies
        recursively via ``parse_nodelist``. Comment tokens yield nothing.

        Raises:
            SyntaxError: a block tag has no name or names an unregistered tag.
        """
        for tok in self.tokens:
            if tok.type == TOK_TEXT:
                yield TextTag(tok.content)
            elif tok.type == TOK_VAR:
                yield VarTag(tok.content)
            elif tok.type == TOK_BLOCK:
                m = re.match(r"\w+", tok.content)
                if not m:
                    raise SyntaxError(tok)
                tag_name = m.group(0)
                tag_class = BlockNode.__tags__.get(tag_name)
                if tag_class is None:
                    raise SyntaxError(f"Unknown tag {tag_name!r} in template {self.name!r}")
                yield tag_class.parse(tok.content[m.end(0):].strip(), self)

    def parse_nodelist(self, ends):
        """Collect nodes until one whose ``name`` is in *ends*.

        The terminating node is not included in the list; it is stored as the
        returned Nodelist's ``endnode`` (``None`` when *ends* is empty and the
        template simply ran out).

        Raises:
            SyntaxError: the template ended while a terminator in *ends* was
                still expected (an unclosed block tag).
        """
        nodelist = Nodelist()
        try:
            node = next(self.parse())
            while node.name not in ends:
                nodelist.append(node)
                node = next(self.parse())
        except StopIteration:
            if ends:
                raise SyntaxError(
                    f"Unexpected end of template {self.name!r}: expected one of {sorted(ends)}"
                ) from None
            node = None
        nodelist.endnode = node
        return nodelist

    def render(self, context, output=None):
        """Render with *context* (a Context, or any mapping to wrap in one).

        Writes into *output* (anything with ``write``) and returns None when it
        is given; otherwise renders into a buffer and returns the text.
        """
        if not isinstance(context, Context):
            context = Context(context)
        dest = StringIO() if output is None else output
        self.nodelist.render(context, dest)
        if output is None:
            return dest.getvalue()
        return None

class AstLiteral:
    """Expression node holding a constant taken from the source (string or number)."""

    def __init__(self, arg):
        self.arg = arg

    def resolve(self, context):
        """Return the stored constant; *context* is not consulted."""
        value = self.arg
        return value

class AstContext:
    """Expression node for a bare variable name looked up in the context."""

    def __init__(self, arg):
        self.arg = arg

    def resolve(self, context):
        """Return the variable's value, or ``""`` when the name is undefined."""
        missing = ""
        return context.get(self.arg, missing)

class AstLookup:
    """Expression node for subscription: ``left[right]``."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def resolve(self, context):
        """Evaluate both sides (left first) and index; lookup errors propagate."""
        container, key = self.left.resolve(context), self.right.resolve(context)
        return container[key]

class AstAttr:
    """Expression node for attribute access: ``left.right`` (right is a name)."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def resolve(self, context):
        """Return the attribute, or ``""`` when the object lacks it."""
        target = self.left.resolve(context)
        fallback = ""
        return getattr(target, self.right, fallback)

class AstCall:
    """Expression node for a call ``func(arg, ...)`` with positional arguments."""

    def __init__(self, func):
        self.func = func
        self.args = []

    def add_arg(self, arg):
        """Append a positional argument expression, evaluated at resolve time."""
        self.args.append(arg)

    def resolve(self, context):
        """Evaluate the callee, then each argument in order, and call it."""
        func = self.func.resolve(context)
        args = [arg.resolve(context) for arg in self.args]

        return func(*args)

class Expression:
    """Recursive-descent parser turning tag text into Ast* nodes.

    Uses Python's tokenizer, so expressions follow Python's lexical rules.
    Invariant: ``self.current`` is the next unconsumed token, and ``_parse``
    leaves it on the first token after the expression it parsed.
    """

    def __init__(self, source):
        self.tokens = tokenize.generate_tokens(StringIO(source).readline)
        self.next()  # prime the first token

    def next(self):
        """Advance to, and return, the next token."""
        self.current = next(self.tokens)
        return self.current

    @staticmethod
    def parse(s):
        """Parse *s* as a single expression and return its Ast node.

        Raises:
            SyntaxError: *s* is not exactly one valid expression.
        """
        p = Expression(s)
        result = p._parse()

        if p.current.exact_type not in (token.NEWLINE, token.ENDMARKER):
            raise SyntaxError(f"Parse ended unexpectedly: {p.current}")

        return result

    def parse_kwargs(self):
        """Parse ``name=expr`` pairs (whitespace separated) up to the end of input.

        Returns:
            dict mapping each name to its Ast node.

        Raises:
            SyntaxError: a pair is malformed.
        """
        kwargs = {}

        tok = self.current

        while tok.exact_type != token.ENDMARKER:
            if tok.exact_type in (token.NEWLINE, token.NL):
                # Line breaks inside a multi-line tag carry no meaning here.
                tok = self.next()
                continue

            if tok.exact_type != token.NAME:
                raise SyntaxError(f"Expected name, found {tok}")
            name = tok.string
            tok = self.next()

            if tok.exact_type != token.EQUAL:
                raise SyntaxError(f"Expected =, found {tok}")
            self.next()  # step past "=" onto the value expression

            kwargs[name] = self._parse()

            # _parse() already left us on the token after the value.
            tok = self.current

        return kwargs

    def _parse(self):
        """Parse one expression starting at ``self.current``.

        Supports string and number literals and names followed by any chain of
        ``.attr``, ``[expr]`` and ``(arg, ...)`` suffixes.
        """
        tok = self.current

        if tok.exact_type in (token.ENDMARKER, token.COMMA):
            raise SyntaxError(f"Expected an expression, found {tok.string!r} in {tok.line!r}")

        if tok.exact_type == token.STRING:
            self.next()
            # NOTE(review): naive unquoting; escape sequences and string
            # prefixes are not interpreted.
            return AstLiteral(tok.string[1:-1])

        if tok.exact_type == token.NUMBER:
            self.next()
            try:
                # Base 0 also accepts 0x/0o/0b literals and underscores.
                value = int(tok.string, 0)
            except ValueError:
                value = float(tok.string)
            return AstLiteral(value)

        if tok.exact_type == token.NAME:
            state = AstContext(tok.string)

            while True:
                tok = self.next()

                if tok.exact_type == token.DOT:
                    tok = self.next()
                    if tok.exact_type != token.NAME:
                        raise SyntaxError(f"Invalid attr lookup: {tok}")
                    state = AstAttr(state, tok.string)

                elif tok.exact_type == token.LSQB:
                    self.next()  # step past "[" onto the subscript
                    right = self._parse()
                    state = AstLookup(state, right)
                    if self.current.exact_type != token.RSQB:
                        raise SyntaxError(f"Expected ] but found {self.current}")

                elif tok.exact_type == token.LPAR:
                    state = AstCall(state)
                    self.next()  # step past "(" onto the first argument (or ")")
                    while self.current.exact_type != token.RPAR:
                        state.add_arg(self._parse())
                        if self.current.exact_type != token.COMMA:
                            break
                        self.next()  # step past ","

                    if self.current.exact_type != token.RPAR:
                        raise SyntaxError(f"Expected ) but found {self.current}")

                else:
                    # No further suffix: the closing "]"/")" of the previous
                    # suffix was consumed by the next() at the top of the loop.
                    break

            return state

        raise SyntaxError(
            f"Error parsing expression {tok.line!r}: Unexpected token {tok.string!r} at position {tok.start[1]}."
        )

class Node:
    """Base class for every parsed template node."""

    # ``None`` here; block-tag subclasses carry their tag name instead.
    name = None

    def __init__(self, content):
        self.content = content

    def render(self, context, output):
        """Write this node's output; the base implementation writes nothing."""
        return None  # pragma: no cover

class TextTag(Node):
    """Literal template text, written to the output unchanged (never escaped)."""

    def render(self, context, output):
        output.write(self.content)

class VarTag(Node):
    """``{{ expr }}``: evaluate an expression and write its (escaped) text."""

    def __init__(self, content):
        super().__init__(content)
        self.expr = Expression.parse(content)

    def render(self, context, output):
        value = str(self.expr.resolve(context))
        # SafeStr survives str() (its __str__ returns itself) and skips escaping.
        if not getattr(value, '__safe__', False):
            value = context.escape(value)
        output.write(value)

class BlockNode(Node):
    """Base class for ``{% tag %}`` nodes.

    Subclasses register themselves under a tag name given as a class keyword,
    e.g. ``class ForTag(BlockNode, name="for")``.
    """

    # Registry of tag name -> BlockNode subclass, consulted by the template parser.
    __tags__: Dict[str, type] = {}
    # Names of attributes holding child Nodelists, searched by nodes_by_type().
    child_nodelists: Iterable[str] = ("nodelist",)

    def __init_subclass__(cls, *, name):
        super().__init_subclass__()
        cls.name = name
        BlockNode.__tags__[name] = cls

    @classmethod
    def parse(cls, content, parser):
        """Build the node from the tag's argument text; the default ignores *parser*."""
        return cls(content)

    def nodes_by_type(self, node_type):
        """Yield descendants of *node_type* found in this node's child nodelists."""
        for attr in self.child_nodelists:
            nodelist = getattr(self, attr, None)
            if nodelist:
                yield from nodelist.nodes_by_type(node_type)

class ForTag(BlockNode, name="for"):
    """``{% for name in expr %}...{% else %}...{% endfor %}``."""

    child_nodelists = ("nodelist", "elselist")

    def __init__(self, argname, iterable, nodelist, elselist):
        self.argname, self.iterable, self.nodelist, self.elselist = argname, iterable, nodelist, elselist

    @classmethod
    def parse(cls, content, parser):
        try:
            argname, iterable = content.split(" in ", 1)
        except ValueError:
            raise SyntaxError(f"Invalid for tag, expected 'name in expression': {content!r}") from None
        nodelist = parser.parse_nodelist({"endfor", "else"})
        has_else = getattr(nodelist.endnode, "name", None) == "else"
        elselist = parser.parse_nodelist({"endfor"}) if has_else else None
        return cls(argname.strip(), Expression.parse(iterable.strip()), nodelist, elselist)

    def render(self, context, output):
        """Render the body once per item; render the else branch if nothing was iterated.

        Inside the loop a pushed scope binds the loop variable and
        ``loopcounter`` (the 0-based index).
        """
        iterable = self.iterable.resolve(context)
        looped = False
        if iterable:
            with context.push():
                for idx, item in enumerate(iterable):
                    looped = True
                    context.update({"loopcounter": idx, self.argname: item})
                    self.nodelist.render(context, output)
        # Also covers truthy-but-empty iterables such as exhausted generators.
        if not looped and self.elselist:
            self.elselist.render(context, output)

class ElseTag(BlockNode, name="else"):
    """Marker for ``{% else %}``: ends the first branch of an enclosing tag."""

class EndforTag(BlockNode, name="endfor"):
    """Marker for ``{% endfor %}``: closes a ``{% for %}`` tag."""

class IfTag(BlockNode, name="if"):
    """``{% if [not] expr %}...{% else %}...{% endif %}``."""

    child_nodelists = ("nodelist", "elselist")

    def __init__(self, condition, nodelist, elselist):
        # A single leading "not" is stripped off and recorded as an inversion flag.
        condition, inv = re.subn(r"^not\s+", "", condition, count=1)
        self.inv, self.condition = bool(inv), Expression.parse(condition)
        self.nodelist, self.elselist = nodelist, elselist

    @classmethod
    def parse(cls, content, parser):
        nodelist = parser.parse_nodelist({"endif", "else"})
        has_else = getattr(nodelist.endnode, "name", None) == "else"
        elselist = parser.parse_nodelist({"endif"}) if has_else else None
        return cls(content, nodelist, elselist)

    def render(self, context, output):
        if self.test_condition(context):
            self.nodelist.render(context, output)
        elif self.elselist:
            self.elselist.render(context, output)

    def test_condition(self, context):
        """Return the truthiness of the condition, inverted when it began with ``not``."""
        return self.inv ^ bool(self.condition.resolve(context))

class EndifTag(BlockNode, name="endif"):
    """Marker for ``{% endif %}``: closes an ``{% if %}`` tag."""

class IncludeTag(BlockNode, name="include"):
    """``{% include name key=expr ... %}``: render another template inline."""

    def __init__(self, template_name, kwargs, loader):
        self.template_name, self.kwargs, self.loader = template_name, kwargs, loader

    @classmethod
    def parse(cls, content, parser):
        if parser.loader is None:
            raise RuntimeError("Can't use {% include %} without a bound Loader")
        tokens = Expression(content)
        template_name = tokens._parse()
        kwargs = tokens.parse_kwargs()
        return cls(template_name, kwargs, parser.loader)

    def render(self, context, output):
        name = self.template_name.resolve(context)
        tmpl = self.loader[name]
        kwargs = {key: expr.resolve(context) for key, expr in self.kwargs.items()}
        # The included template sees the current variables plus the kwargs in a
        # child context, so writes made while rendering it stay out of this one.
        ctx = context.new_child(kwargs)
        tmpl.render(ctx, output)

class LoadTag(BlockNode, name="load"):
    """``{% load ... %}`` is accepted but ignored: its arguments are discarded
    and the node renders nothing."""

    @classmethod
    def parse(cls, content, parser):
        return cls(None)

class ExtendsTag(BlockNode, name="extends"):
    """``{% extends parent %}``: render the parent template using this template's blocks.

    Everything after the tag is parsed into ``nodelist``; only the
    ``{% block %}`` nodes inside it take effect, as overrides of the parent's.
    """

    def __init__(self, parent, loader, nodelist):
        self.parent, self.loader, self.nodelist = parent, loader, nodelist

    @classmethod
    def parse(cls, content, parser):
        if parser.loader is None:
            raise RuntimeError("Can't use {% extends %} without a bound Loader")
        parent = Expression.parse(content)
        nodelist = parser.parse_nodelist([])
        return cls(parent, parser.loader, nodelist)

    def render(self, context, output):
        parent = self.loader[self.parent.resolve(context)]
        block_context = getattr(context, "block_context", None)
        if block_context is None:
            block_context = context.block_context = defaultdict(deque)
        # Queue this template's blocks first: BlockTag pops from the left, so
        # the most-derived override is the one rendered.
        for block in self.nodelist.nodes_by_type(BlockTag):
            block_context[block.block_name].append(block)
        # Only the root template (one that does not itself extend) queues its
        # own blocks here; an intermediate parent queues them when its own
        # extends tag renders. Checking every top-level node tolerates leading
        # text before the tag and an empty parent.
        if not any(node.name == "extends" for node in parent.nodelist):
            for block in parent.nodelist.nodes_by_type(BlockTag):
                block_context[block.block_name].append(block)
        parent.render(context, output)

class BlockTag(BlockNode, name="block"):
    """``{% block name %}...{% endblock %}``: an overridable region for template inheritance."""

    def __init__(self, name, nodelist):
        self.block_name, self.nodelist = name, nodelist

    @classmethod
    def parse(cls, content, parser):
        m = re.match(r"\w+", content)
        if not m:
            raise ValueError(f'Invalid block label: {content !r}')
        name = m.group(0)
        nodelist = parser.parse_nodelist({"endblock"})
        return cls(name, nodelist)

    def render(self, context, output):
        # Kept on the node so super(), called later from inside the template
        # with no arguments, renders into the same context and output.
        # NOTE(review): templates are cached and shared by the loader, so this
        # per-node state makes concurrent renders of one template unsafe.
        self.context = context
        self.output = output
        self._render()

    def _render(self):
        """Render the next queued version of this block, or this block itself."""
        block_context = getattr(self.context, "block_context", None)
        queue = block_context[self.block_name] if block_context else None
        popped = bool(queue)
        block = queue.popleft() if popped else self
        try:
            with self.context.push({"block": self}):
                block.nodelist.render(self.context, self.output)
        finally:
            # Restore the queue so outer renders (and later renders) see it unchanged.
            if popped:
                queue.appendleft(block)

    def super(self):
        """Render the next ancestor's version of this block in place.

        Called from a template as ``{{ block.super() }}``: the content is
        written straight to the output and the call evaluates to ``""``.
        """
        block_context = getattr(self.context, "block_context", None)
        if block_context and block_context[self.block_name]:
            self._render()
        return ""

class EndBlockTag(BlockNode, name="endblock"):
    """Marker for ``{% endblock %}``: closes a ``{% block %}`` tag."""

class WithTag(BlockNode, name="with"):
    """``{% with name=expr ... %}...{% endwith %}``: render the body with extra variables."""

    def __init__(self, kwargs, nodelist):
        self.kwargs, self.nodelist = kwargs, nodelist

    @classmethod
    def parse(cls, content, parser):
        kwargs = Expression(content).parse_kwargs()
        nodelist = parser.parse_nodelist({"endwith"})
        return cls(kwargs, nodelist)

    def render(self, context, output):
        # Every value is resolved in the outer scope before the new scope is pushed.
        kwargs = {key: value.resolve(context) for key, value in self.kwargs.items()}
        with context.push(kwargs):
            self.nodelist.render(context, output)

class EndWithTag(BlockNode, name="endwith"):
    """Marker for ``{% endwith %}``: closes a ``{% with %}`` tag."""

class CaseTag(BlockNode, name="case"):
    """``{% case expr %}`` with ``when``/``else`` children, closed by ``{% endcase %}``.

    Renders only the first ``when`` child whose term equals the case term; an
    ``else`` child matches anything and is always tried last.
    """

    def __init__(self, term, nodelist):
        self.term, self.nodelist = term, nodelist

    @classmethod
    def parse(cls, content, parser):
        term = Expression.parse(content)
        nodelist = parser.parse_nodelist(["endcase"])
        else_found = False
        for node in nodelist:
            # NOTE(review): text nodes (including whitespace between children)
            # have name None and are therefore rejected here.
            if node.name not in {"when", "else"}:
                raise SyntaxError(f"Only 'when' and 'else' allowed as children of case. Found: {node}")
            if node.name == "else":
                if else_found:
                    raise SyntaxError("Case tag can only have one else child")
                else_found = True
        # "when" sorts after "else", so reverse=True puts every when first; the
        # sort is stable, so the whens keep their source order.
        nodelist.sort(key=lambda x: x.name, reverse=True)
        return cls(term, nodelist)

    def render(self, context, output):
        value = self.term.resolve(context)
        for node in self.nodelist:
            if node.name == "when":
                other = node.term.resolve(context)
            else:
                # The else child matches unconditionally. NOTE(review): ElseTag
                # has no body of its own, so this branch currently renders nothing.
                other = value
            if value == other:
                node.render(context, output)
                # Only the first matching branch renders.
                return

class WhenTag(BlockNode, name="when"):
    """``{% when expr %}...{% endwhen %}``: one branch of an enclosing ``{% case %}``."""

    def __init__(self, term, nodelist):
        self.term, self.nodelist = term, nodelist

    @classmethod
    def parse(cls, content, parser):
        term = Expression.parse(content)
        # parse_nodelist() needs an explicit terminator set (calling it with none
        # raised TypeError). The terminator must not be another case child, or
        # it would be consumed here instead of reaching CaseTag.
        # NOTE(review): {% endwhen %} is assumed as the closing tag — confirm.
        nodelist = parser.parse_nodelist({"endwhen"})
        return cls(term, nodelist)

    def render(self, context, output):
        self.nodelist.render(context, output)


class EndWhenTag(BlockNode, name="endwhen"):
    """Marker for ``{% endwhen %}``: closes a ``{% when %}`` branch."""

class EndCaseTag(BlockNode, name="endcase"):
    """Marker for ``{% endcase %}``: closes a ``{% case %}`` tag."""