import ast
import html
import importlib
import re
import token
import tokenize
from collections import ChainMap, defaultdict, deque, namedtuple
from io import StringIO
from pathlib import Path
from typing import Dict, Iterable, Type

TOK_COMMENT = "comment"
TOK_TEXT = "text"
TOK_VAR = "var"
TOK_BLOCK = "block"

# One alternative per tag kind; the name of the group that matched is used as
# the token type.  Comments may be empty (``{##}``).
tag_re = re.compile(
    r"{%\s*(?P<block>.+?)\s*%}|{{\s*(?P<var>.+?)\s*}}|{#\s*(?P<comment>.*?)\s*#}",
    re.DOTALL,
)

Token = namedtuple("Token", "type content")


class SafeStr(str):
    """A string that is already escaped and must not be escaped again on output."""

    __safe__ = True

    def __str__(self):
        # Returning ``self`` keeps the subclass (and its ``__safe__`` marker)
        # through the ``str()`` call in ``VarTag.render``.
        return self


def tokenise(template):
    """Split *template* into a stream of ``Token`` tuples.

    Text between tags becomes a ``TOK_TEXT`` token.  Each tag becomes a token
    whose type is the regex group that matched (``"block"``, ``"var"`` or
    ``"comment"``) and whose content is the stripped tag body.
    """
    upto = 0
    for m in tag_re.finditer(template):
        start, end = m.span()
        if upto < start:
            yield Token(TOK_TEXT, template[upto:start])
        upto = end
        mode = m.lastgroup
        yield Token(mode, m[mode].strip())
    if upto < len(template):
        yield Token(TOK_TEXT, template[upto:])


class TemplateLoader(dict):
    """Load templates by name from a list of directories, caching each one.

    ``loader[name]`` returns the cached ``Template``; a miss calls ``load``.
    """

    def __init__(self, paths):
        super().__init__()
        self.paths = [Path(path).resolve() for path in paths]

    def load(self, name, encoding="utf8"):
        """Parse and return the first file called *name* found on ``paths``.

        Raises:
            LookupError: if no search directory contains *name*.
        """
        # NOTE(review): *name* is joined onto each search path without a
        # containment check, so "../x" or absolute names can reach files
        # outside the template directories.  Only pass trusted names.
        for path in self.paths:
            full_path = path / name
            if full_path.is_file():
                return Template(full_path.read_text(encoding), loader=self, name=name)
        raise LookupError(name)

    def __missing__(self, key):
        self[key] = tmpl = self.load(key)
        return tmpl


class Context(ChainMap):
    """Layered variable lookup used while rendering.

    ``maps[0]`` is the innermost scope.  ``True``, ``False`` and ``None`` are
    always resolvable from an outer scope.  *escape* is applied to every
    ``{{ var }}`` output that is not marked safe.
    """

    def __init__(self, *args, escape=html.escape):
        super().__init__(*args)
        self.maps.append({"True": True, "False": False, "None": None})
        self.escape = escape

    def new_child(self, m=None, **kwargs):
        """Return a child context whose innermost scope is *m*.

        Unlike plain ``ChainMap.new_child``, the child keeps this context's
        ``escape`` function instead of reverting to the default.
        """
        child = super().new_child(m, **kwargs)
        child.escape = self.escape
        return child

    def push(self, data=None):
        """Push a new innermost scope; use as ``with context.push(...):``."""
        self.maps.insert(0, data or {})
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.maps.pop(0)


class Nodelist(list):
    """A list of parsed nodes.

    ``Template.parse_nodelist`` sets ``endnode`` to the node that terminated
    the list, or ``None`` if the input ran out first.
    """

    def render(self, context, output):
        for node in self:
            node.render(context, output)

    def nodes_by_type(self, node_type):
        """Yield every node of *node_type*, descending into block nodes' children."""
        for node in self:
            if isinstance(node, node_type):
                yield node
            if isinstance(node, BlockNode):
                yield from node.nodes_by_type(node_type)


class Template:
    """A parsed template.

    Parsing happens in the constructor.  *loader* is required by
    ``{% include %}`` and ``{% extends %}``; *name* is kept for error reporting.
    """

    def __init__(self, src, loader=None, name=None):
        self.tokens, self.loader = tokenise(src), loader
        self.name = name  # So we can report where the fault was
        self.nodelist = self.parse_nodelist([])

    def parse(self):
        """Yield nodes parsed from the shared token stream.

        A block tag is dispatched to the ``BlockNode`` subclass registered
        under its first word.  That class may consume more tokens (its body)
        before the node is yielded.  Comment tokens are dropped.
        """
        for tok in self.tokens:
            if tok.type == TOK_TEXT:
                yield TextTag(tok.content)
            elif tok.type == TOK_VAR:
                yield VarTag(tok.content)
            elif tok.type == TOK_BLOCK:
                m = re.match(r"\w+", tok.content)
                if not m:
                    raise SyntaxError(tok)
                yield BlockNode.__tags__[m.group(0)].parse(tok.content[m.end(0):].strip(), self)

    def parse_nodelist(self, ends):
        """Collect nodes until one whose ``name`` is in *ends*.

        The terminating node is consumed and stored as ``endnode`` on the
        returned list; it is ``None`` if the tokens ran out first.
        """
        nodelist = Nodelist()
        try:
            node = next(self.parse())
            while node.name not in ends:
                nodelist.append(node)
                node = next(self.parse())
        except StopIteration:
            node = None
        nodelist.endnode = node
        return nodelist

    def render(self, context, output=None):
        """Render using *context* (a ``Context`` or any mapping).

        Writes to *output* and returns ``None`` if *output* is given;
        otherwise returns the rendered text.
        """
        if not isinstance(context, Context):
            context = Context(context)
        dest = StringIO() if output is None else output
        self.nodelist.render(context, dest)
        if output is None:
            return dest.getvalue()


class AstLiteral:
    """A constant value."""

    def __init__(self, arg):
        self.arg = arg

    def resolve(self, context):
        return self.arg


class AstContext:
    """A variable looked up in the context; missing names resolve to ``""``."""

    def __init__(self, arg):
        self.arg = arg

    def resolve(self, context):
        return context.get(self.arg, "")


class AstLookup:
    """``left[right]`` subscription; lookup errors propagate."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def resolve(self, context):
        left = self.left.resolve(context)
        right = self.right.resolve(context)
        return left[right]


class AstAttr:
    """``left.right`` attribute access; a missing attribute resolves to ``""``."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def resolve(self, context):
        left = self.left.resolve(context)
        return getattr(left, self.right, "")


class AstCall:
    """``func(args...)`` call with positional arguments."""

    def __init__(self, func):
        self.func = func
        self.args = []

    def add_arg(self, arg):
        self.args.append(arg)

    def resolve(self, context):
        func = self.func.resolve(context)
        args = [arg.resolve(context) for arg in self.args]
        return func(*args)


class Expression:
    """Recursive-descent parser for the template expression language.

    Source is tokenised with Python's ``tokenize`` module.  Supported forms:
    string and number literals, names, ``.attr`` access, ``[key]`` lookup and
    ``f(args)`` calls, chained in any order.  ``self.current`` is always the
    next unconsumed token.
    """

    def __init__(self, source):
        self.tokens = tokenize.generate_tokens(StringIO(source).readline)
        self.next()  # prime the first token

    def next(self):
        self.current = next(self.tokens)
        return self.current

    @staticmethod
    def parse(s):
        """Parse *s* as exactly one expression and return its AST node."""
        p = Expression(s)
        result = p._parse()
        if p.current.exact_type not in (token.NEWLINE, token.ENDMARKER):
            raise SyntaxError(f"Parse ended unexpectedly: {p.current}")
        return result

    def parse_kwargs(self):
        """Parse ``name=expr`` pairs until the end of input.

        Pairs may be separated by commas or by whitespace alone.
        Returns a dict mapping each name to its AST node.
        """
        kwargs = {}
        tok = self.current
        while tok.exact_type != token.ENDMARKER:
            if tok.exact_type == token.NEWLINE:
                tok = self.next()
                continue
            if tok.exact_type != token.NAME:
                raise SyntaxError(f"Expected name, found {tok}")
            name = tok.string
            tok = self.next()
            if tok.exact_type != token.EQUAL:
                raise SyntaxError(f"Expected =, found {tok}")
            self.next()
            kwargs[name] = self._parse()
            tok = self.current
            # Only an optional comma is consumed here.  Any other token is
            # validated as the next name on the following iteration.
            if tok.exact_type == token.COMMA:
                tok = self.next()
        return kwargs

    def _parse(self):
        """Parse one expression starting at ``self.current``.

        On return, ``self.current`` is the first token after the expression.

        Raises:
            SyntaxError: if no valid expression starts at ``self.current``.
        """
        tok = self.current
        if tok.exact_type in (token.STRING, token.NUMBER):
            self.next()
            try:
                # Decodes escapes, prefixes, triple quotes, hex/octal/float forms.
                value = ast.literal_eval(tok.string)
            except (ValueError, SyntaxError) as exc:
                raise SyntaxError(f"Invalid literal {tok.string!r}") from exc
            return AstLiteral(value)
        if tok.exact_type == token.NAME:
            state = AstContext(tok.string)
            while True:
                tok = self.next()
                if tok.exact_type == token.DOT:
                    tok = self.next()
                    if tok.exact_type != token.NAME:
                        raise SyntaxError(f"Invalid attr lookup: {tok}")
                    state = AstAttr(state, tok.string)
                elif tok.exact_type == token.LSQB:
                    self.next()
                    right = self._parse()
                    state = AstLookup(state, right)
                    if self.current.exact_type != token.RSQB:
                        raise SyntaxError(f"Expected ] but found {self.current}")
                elif tok.exact_type == token.LPAR:
                    state = AstCall(state)
                    self.next()
                    while self.current.exact_type != token.RPAR:
                        state.add_arg(self._parse())
                        if self.current.exact_type != token.COMMA:
                            break
                        self.next()
                    if self.current.exact_type != token.RPAR:
                        raise SyntaxError(f"Expected ) but found {self.current}")
                    # Leave ``current`` on ")" (as the "]" branch does); the
                    # next() at the top of the loop steps past it.
                else:
                    break
            return state
        raise SyntaxError(
            f"Error parsing expression {tok.line !r}: Unexpected token {tok.string!r} at position {tok.start[1]}."
        )


class Node:
    """Base class for parsed nodes; ``name`` identifies block tags (None otherwise)."""

    name = None

    def __init__(self, content):
        self.content = content

    def render(self, context, output):
        pass  # pragma: no cover


class TextTag(Node):
    """Literal template text, written out unchanged."""

    def render(self, context, output):
        output.write(self.content)


class VarTag(Node):
    """``{{ expr }}``: write the value of *expr*, escaped unless it is safe."""

    def __init__(self, content):
        self.expr = Expression.parse(content)

    def render(self, context, output):
        value = str(self.expr.resolve(context))
        if not getattr(value, '__safe__', False):
            value = context.escape(value)
        output.write(value)


class BlockNode(Node):
    """Base class for ``{% tag %}`` nodes.

    A subclass registers itself under the ``name`` class keyword, e.g.
    ``class ForTag(BlockNode, name="for")``.  ``child_nodelists`` lists the
    attributes that hold nested node lists; ``nodes_by_type`` walks them.
    """

    __tags__: Dict[str, Type["BlockNode"]] = {}
    child_nodelists: Iterable[str] = ("nodelist",)

    def __init_subclass__(cls, *, name):
        super().__init_subclass__()
        cls.name = name
        BlockNode.__tags__[name] = cls

    @classmethod
    def parse(cls, content, parser):
        """Build the node from the tag *content*; subclasses may also consume a body."""
        return cls(content)

    def nodes_by_type(self, node_type):
        for attr in self.child_nodelists:
            nodelist = getattr(self, attr, None)
            if nodelist:
                yield from nodelist.nodes_by_type(node_type)


def _check_closed(nodelist, tag):
    """Raise SyntaxError if *nodelist* ran out of input before its end tag."""
    if nodelist.endnode is None:
        raise SyntaxError(f"Unclosed {{% {tag} %}} tag")


class ForTag(BlockNode, name="for"):
    """``{% for name in expr %}...{% else %}...{% endfor %}``.

    ``loopcounter`` holds the zero-based index inside the loop.  The else
    branch renders when the iterable is falsy or yields no items.
    """

    child_nodelists = ("nodelist", "elselist")

    def __init__(self, argname, iterable, nodelist, elselist):
        self.argname, self.iterable, self.nodelist, self.elselist = argname, iterable, nodelist, elselist

    @classmethod
    def parse(cls, content, parser):
        argname, iterable = content.split(" in ", 1)
        nodelist = parser.parse_nodelist({"endfor", "else"})
        _check_closed(nodelist, "for")
        elselist = parser.parse_nodelist({"endfor"}) if nodelist.endnode.name == "else" else None
        return cls(argname.strip(), Expression.parse(iterable.strip()), nodelist, elselist)

    def render(self, context, output):
        iterable = self.iterable.resolve(context)
        looped = False
        if iterable:
            with context.push():
                for idx, item in enumerate(iterable):
                    looped = True
                    context.update({"loopcounter": idx, self.argname: item})
                    self.nodelist.render(context, output)
        # Checked after iterating, so empty (but truthy) iterators get the else branch.
        if not looped and self.elselist:
            self.elselist.render(context, output)


class ElseTag(BlockNode, name="else"):
    """Marker ending the main branch of ``for``/``if``; a branch head inside ``case``."""


class EndforTag(BlockNode, name="endfor"):
    pass


class IfTag(BlockNode, name="if"):
    """``{% if [not] expr %}...{% else %}...{% endif %}``."""

    child_nodelists = ("nodelist", "elselist")

    def __init__(self, condition, nodelist, elselist):
        condition, inv = re.subn(r"^not\s+", "", condition, count=1)
        self.inv, self.condition = bool(inv), Expression.parse(condition)
        self.nodelist, self.elselist = nodelist, elselist

    @classmethod
    def parse(cls, content, parser):
        nodelist = parser.parse_nodelist({"endif", "else"})
        _check_closed(nodelist, "if")
        elselist = parser.parse_nodelist({"endif"}) if nodelist.endnode.name == "else" else None
        return cls(content, nodelist, elselist)

    def render(self, context, output):
        if self.test_condition(context):
            self.nodelist.render(context, output)
        elif self.elselist:
            self.elselist.render(context, output)

    def test_condition(self, context):
        return self.inv ^ bool(self.condition.resolve(context))


class EndifTag(BlockNode, name="endif"):
    pass


class IncludeTag(BlockNode, name="include"):
    """``{% include name_expr key=expr ... %}``: render another template inline.

    The included template sees the current context plus the given keyword
    values in a new innermost scope.
    """

    def __init__(self, template_name, kwargs, loader):
        self.template_name, self.kwargs, self.loader = template_name, kwargs, loader

    @classmethod
    def parse(cls, content, parser):
        if parser.loader is None:
            raise RuntimeError("Can't use {% include %} without a bound Loader")
        tokens = Expression(content)
        template_name = tokens._parse()
        kwargs = tokens.parse_kwargs()
        return cls(template_name, kwargs, parser.loader)

    def render(self, context, output):
        name = self.template_name.resolve(context)
        tmpl = self.loader[name]
        kwargs = {key: expr.resolve(context) for key, expr in self.kwargs.items()}
        tmpl.render(context.new_child(kwargs), output)


class LoadTag(BlockNode, name="load"):
    """``{% load module %}``: import *module* when the template is parsed.

    Presumably this lets the module register extra ``BlockNode`` tags.
    NOTE(review): this imports any module the template names, so templates
    must be trusted.
    """

    @classmethod
    def parse(cls, content, parser):
        importlib.import_module(content)
        return cls(None)


class ExtendsTag(BlockNode, name="extends"):
    """``{% extends name_expr %}``: render the parent with this template's blocks.

    Everything after the tag is parsed into ``nodelist`` only so that its
    ``{% block %}`` overrides can be collected.  Other content is not rendered.
    """

    def __init__(self, parent, loader, nodelist):
        self.parent, self.loader, self.nodelist = parent, loader, nodelist

    @classmethod
    def parse(cls, content, parser):
        if parser.loader is None:
            raise RuntimeError("Can't use {% extends %} without a bound Loader")
        parent = Expression.parse(content)
        nodelist = parser.parse_nodelist([])
        return cls(parent, parser.loader, nodelist)

    def render(self, context, output):
        parent = self.loader[self.parent.resolve(context)]
        block_context = getattr(context, "block_context", None)
        # The outermost {% extends %} owns the block registry and removes it
        # afterwards, so a reused Context does not carry stale blocks.
        owner = block_context is None
        if owner:
            block_context = context.block_context = defaultdict(deque)
        try:
            # Each deque is ordered most-derived first, root template last.
            for block in self.nodelist.nodes_by_type(BlockTag):
                block_context[block.block_name].append(block)
            if not any(node.name == "extends" for node in parent.nodelist):
                # The parent is the root template and will not register its own
                # blocks, so add them here as the final fallbacks.
                for block in parent.nodelist.nodes_by_type(BlockTag):
                    block_context[block.block_name].append(block)
            parent.render(context, output)
        finally:
            if owner:
                del context.block_context


class BlockTag(BlockNode, name="block"):
    """``{% block label %}...{% endblock %}``: an overridable section.

    Under ``{% extends %}``, the most-derived override renders in place of this
    block, and ``{{ block.super }}`` renders the next definition up the chain.
    """

    def __init__(self, name, nodelist):
        self.block_name, self.nodelist = name, nodelist

    @classmethod
    def parse(cls, content, parser):
        m = re.match(r"\w+", content)
        if not m:
            raise ValueError(f'Invalid block label: {content !r}')
        name = m.group(0)
        nodelist = parser.parse_nodelist({"endblock"})
        return cls(name, nodelist)

    def render(self, context, output):
        # Stored on the node so that ``super`` (reached via a ``{{ block.super }}``
        # lookup) can write to the same output.
        # NOTE(review): this makes concurrent renders of one Template unsafe.
        self.context = context
        self.output = output
        self._render()

    def _render(self):
        block_context = getattr(self.context, "block_context", None)
        if not block_context:
            block = self
        else:
            chain = block_context[self.block_name]
            if not chain:
                # No definition further up (e.g. block.super in the root template).
                return
            block = chain.popleft()
        try:
            with self.context.push({"block": self}):
                block.nodelist.render(self.context, self.output)
        finally:
            if block_context:
                block_context[self.block_name].appendleft(block)

    @property
    def super(self):
        """Render the parent's version of this block; evaluates to ``""``.

        Outside template inheritance there is no parent, so nothing is rendered.
        """
        if getattr(self.context, "block_context", None):
            self._render()
        return ""


class EndBlockTag(BlockNode, name="endblock"):
    pass


class WithTag(BlockNode, name="with"):
    """``{% with key=expr ... %}...{% endwith %}``: bind names for the body."""

    def __init__(self, kwargs, nodelist):
        self.kwargs, self.nodelist = kwargs, nodelist

    @classmethod
    def parse(cls, content, parser):
        kwargs = Expression(content).parse_kwargs()
        nodelist = parser.parse_nodelist({"endwith"})
        return cls(kwargs, nodelist)

    def render(self, context, output):
        kwargs = {key: value.resolve(context) for key, value in self.kwargs.items()}
        with context.push(kwargs):
            self.nodelist.render(context, output)


class EndWithTag(BlockNode, name="endwith"):
    pass


class CaseTag(BlockNode, name="case"):
    """``{% case expr %}{% when v %}...{% else %}...{% endcase %}``.

    Renders the first ``when`` branch whose value equals *expr*; otherwise it
    renders the ``else`` branch, if present.  Only whitespace may appear
    between ``{% case %}`` and the first branch.
    """

    def __init__(self, term, nodelist):
        self.term, self.nodelist = term, nodelist

    @classmethod
    def parse(cls, content, parser):
        term = Expression.parse(content)
        branch_ends = {"when", "else", "endcase"}
        preamble = parser.parse_nodelist(branch_ends)
        for node in preamble:
            if not (isinstance(node, TextTag) and not node.content.strip()):
                raise SyntaxError(f"Only 'when' and 'else' allowed as children of case. Found: {node}")
        branches = Nodelist()
        else_found = False
        marker = preamble.endnode
        # Each branch body runs until the next when/else/endcase marker.
        while marker is not None and marker.name != "endcase":
            if marker.name == "else":
                if else_found:
                    raise SyntaxError("Case tag can only have one else child")
                else_found = True
            body = parser.parse_nodelist(branch_ends)
            marker.nodelist = body
            branches.append(marker)
            marker = body.endnode
        if marker is None:
            raise SyntaxError("Unclosed {% case %} tag")
        # Stable sort: ``when`` branches keep source order; ``else`` goes last.
        branches.sort(key=lambda node: node.name == "else")
        return cls(term, branches)

    def render(self, context, output):
        value = self.term.resolve(context)
        for branch in self.nodelist:
            if branch.name == "else" or value == branch.term.resolve(context):
                branch.nodelist.render(context, output)
                return


class WhenTag(BlockNode, name="when"):
    """``{% when expr %}`` branch head inside ``{% case %}``.

    The body is not consumed here.  It ends at the next when/else/endcase,
    which ``CaseTag.parse`` must see, so CaseTag fills in ``nodelist``.
    Outside a case the body stays empty.
    """

    def __init__(self, term, nodelist):
        self.term, self.nodelist = term, nodelist

    @classmethod
    def parse(cls, content, parser):
        return cls(Expression.parse(content), Nodelist())

    def render(self, context, output):
        self.nodelist.render(context, output)


class EndCaseTag(BlockNode, name="endcase"):
    pass