package com.hubspot.jinjava.lib.tag;

import static org.assertj.core.api.Assertions.assertThat;

import com.google.common.collect.ImmutableList;
import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import com.hubspot.jinjava.interpret.Context;
import com.hubspot.jinjava.interpret.JinjavaInterpreter;
import com.hubspot.jinjava.lib.filter.Filter;
import com.hubspot.jinjava.lib.fn.ELFunctionDefinition;
import com.hubspot.jinjava.lib.fn.MacroFunction;
import com.hubspot.jinjava.tree.Node;
import com.hubspot.jinjava.tree.TextNode;
import com.hubspot.jinjava.tree.parse.DefaultTokenScannerSymbols;
import com.hubspot.jinjava.tree.parse.TextToken;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.runners.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class ValidationModeTest {
  // Interpreter obtained in setup() via jinjava.newInterpreter().
  JinjavaInterpreter interpreter;
  // Interpreter built in setup() over the same context, with validation mode enabled.
  JinjavaInterpreter validatingInterpreter;

  // Engine instance; recreated by setup() before every test.
  Jinjava jinjava;
  // Context of `interpreter`; setup() also hands it to `validatingInterpreter`.
  private Context context;

  // Counting filter that setup() registers on the engine's global context.
  ValidationFilter validationFilter;

  /**
   * Pass-through filter named {@code validation_filter} that counts how many
   * times it has been applied, so tests can tell whether an expression
   * actually invoked it.
   */
  class ValidationFilter implements Filter {
    private int timesApplied;

    @Override
    public String getName() {
      return "validation_filter";
    }

    /** Bumps the counter and hands {@code var} back unchanged. */
    @Override
    public Object filter(Object var, JinjavaInterpreter interpreter, String... args) {
      timesApplied += 1;
      return var;
    }

    /** Number of {@code filter} calls made on this instance so far. */
    public int getExecutionCount() {
      return timesApplied;
    }
  }

  // Call counter for validationTestFunction(); static so the static function can reach it.
  private static int functionExecutionCount;

  /**
   * Increments the shared call counter and returns its new value. Registered in
   * {@code setup()} as the template function {@code validation_test}.
   */
  public static int validationTestFunction() {
    functionExecutionCount += 1;
    return functionExecutionCount;
  }

  /**
   * Builds a fresh engine for each test: registers the counting filter and the
   * {@code validation_test} function on the global context, creates a regular
   * interpreter plus a validation-mode interpreter over the same context, and
   * pushes the regular interpreter as the current one.
   */
  @Before
  public void setup() {
    jinjava = new Jinjava();

    validationFilter = new ValidationFilter();
    jinjava.getGlobalContext().registerFilter(validationFilter);

    jinjava
      .getGlobalContext()
      .registerFunction(
        new ELFunctionDefinition(
          "",
          "validation_test",
          ValidationModeTest.class,
          "validationTestFunction"
        )
      );

    interpreter = jinjava.newInterpreter();
    context = interpreter.getContext();

    JinjavaConfig validationConfig = JinjavaConfig
      .newBuilder()
      .withValidationMode(true)
      .build();
    validatingInterpreter = new JinjavaInterpreter(jinjava, context, validationConfig);

    JinjavaInterpreter.pushCurrent(interpreter);
  }

  /** Undoes the {@code JinjavaInterpreter.pushCurrent} call made in {@code setup()}. */
  @After
  public void tearDown() {
    JinjavaInterpreter.popCurrent();
  }

  /**
   * Expects two errors: the malformed expression at top level and the one
   * inside the {@code if false} body.
   */
  @Test
  public void itResolvesAllIfExpressionsInValidationMode() {
    String template =
      "{{ badCode( }}" + "{% if false %}" + "  {{ badCode( }}" + "{% endif %}";

    validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(2);
  }

  /**
   * Expects two errors: the malformed expression at top level and the one
   * inside the {@code unless false} body.
   */
  @Test
  public void itResolvesAllUnlessExpressionsInValidationMode() {
    String template =
      "{{ badCode( }}" + "{% unless false %}" + "  {{ badCode( }}" + "{% endunless %}";

    validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(2);
  }

  /**
   * Expects four errors in total for a malformed expression at top level plus
   * one in the body of a three-element loop (presumably one per iteration).
   */
  @Test
  public void itResolvesAllForExpressionsInValidationMode() {
    String template =
      "{{ badCode( }}" +
      "{% for i in [1, 2, 3] %}" +
      "  {{ badCode( }}" +
      "{% endfor %}";

    validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(4);
  }

  /**
   * An inner loop nested in an empty outer loop still yields errors (four in
   * total here), while only the trailing text is rendered.
   */
  @Test
  public void itResolvesNestedForExpressionsInValidationMode() {
    String template =
      "{{ badCode( }}" +
      "{% for i in [] %}" +
      "  outer loop" +
      "  {% for i in [1, 2, 3] %}" +
      "    inner loop  {{ badCode( }}" +
      "  {% endfor %}" +
      "{% endfor %}" +
      "done";

    String rendered = validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(4);
    assertThat(rendered.trim()).isEqualTo("done");
  }

  /**
   * The body of a loop over an empty list is still checked (two errors
   * overall) but contributes nothing to the output.
   */
  @Test
  public void itResolvesZeroLoopForExpressionsInValidationMode() {
    String template =
      "{{ badCode( }}" +
      "{% for i in [] %}" +
      "in loop {{ badCode( }}" +
      "{% endfor %}" +
      "hi";

    String rendered = validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(2);
    assertThat(rendered.trim()).isEqualTo("hi");
  }

  /**
   * Reading a property of the loop variable inside an empty loop must not
   * register any error.
   */
  @Test
  public void itAllowsPropertyReferenceInForLoopInValidationMode() {
    String template = "{% for i in [] %}" + "{{ i.test }}" + "{% endfor %}" + "hi";

    String rendered = validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(0);
    assertThat(rendered.trim()).isEqualTo("hi");
  }

  /**
   * Arithmetic and string concatenation on properties of the loop variable
   * inside an empty loop must not register any error.
   */
  @Test
  public void itAllowsPropertyReferenceAndTypeCoercionInForLoopInValidationMode() {
    String template =
      "{% for i in [] %}" +
      "{{ i.test + 100 }}" +
      "{{ i.nope ~ 'hello' }}" +
      "{% endfor %}" +
      "hi";

    String rendered = validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(0);
    assertThat(rendered.trim()).isEqualTo("hi");
  }

  /**
   * Same as the empty-list case, but with a tuple-unpacking loop over the
   * items of an empty map: two errors, and only the trailing text is rendered.
   */
  @Test
  public void itResolvesZeroLoopTupleForExpressionsInValidationMode() {
    String template =
      "{{ badCode( }}" +
      "{% set map = {} %}" +
      "{% for a, b in map.items() %}" +
      "in loop {{ badCode( }}" +
      "{% endfor %}" +
      "hi";

    String rendered = validatingInterpreter.render(template);

    assertThat(validatingInterpreter.getErrors()).hasSize(2);
    assertThat(rendered.trim()).isEqualTo("hi");
  }

  /**
   * A {@code set} inside an {@code if false} body must not overwrite the
   * variable assigned before the block.
   */
  @Test
  public void itDoesNotSetValuesInValidatedBlocks() {
    String template =
      "{% set foo = \"orig value\" %}" +
      "{% if false %}" +
      "  {% set foo = \"in false block\" %}" +
      "{% endif %}" +
      "{{ foo }}";

    String rendered = validatingInterpreter.render(template);

    assertThat(rendered.trim()).isEqualTo("orig value");
    assertThat(validatingInterpreter.getErrors()).isEmpty();
  }

  /**
   * A {@code set} inside an {@code if true} that is itself nested in an
   * {@code if false} must not overwrite the outer variable.
   */
  @Test
  public void itDoesNotSetValuesInNestedValidatedBlocks() {
    String template =
      "{% set foo = \"orig value\" %}" +
      "{% if false %}" +
      " {% if true %}" +
      "  {% set foo = \"in nested block\" %}" +
      " {% endif %}" +
      "{% endif %}" +
      "{{ foo }}";

    String rendered = validatingInterpreter.render(template);

    assertThat(rendered.trim()).isEqualTo("orig value");
    assertThat(validatingInterpreter.getErrors()).isEmpty();
  }

  /**
   * Text inside an {@code if false} body, including text in a nested
   * {@code if true}, must not appear in the output.
   */
  @Test
  public void itDoesNotPrintValuesInNestedValidatedBlocks() {
    String template =
      "hi " +
      "{% if false %}" +
      " hidey " +
      "  {% if true %}" +
      "    hey" +
      "  {% endif %}" +
      "{% endif %}" +
      "there";

    String rendered = validatingInterpreter.render(template);

    assertThat(rendered.trim()).isEqualTo("hi there");
    assertThat(validatingInterpreter.getErrors()).isEmpty();
  }

  /**
   * {@link MacroFunction} that counts calls to {@code doEvaluate} before
   * delegating to the inherited implementation.
   */
  private class InstrumentedMacroFunction extends MacroFunction {
    private int evaluations;

    InstrumentedMacroFunction(
      List<Node> content,
      String name,
      LinkedHashMap<String, Object> argNamesWithDefaults,
      boolean caller,
      Context localContextScope
    ) {
      // NOTE(review): the two trailing -1 values look like "unknown position"
      // placeholders for the superclass constructor — confirm against MacroFunction.
      super(content, name, argNamesWithDefaults, caller, localContextScope, -1, -1);
    }

    /** Number of {@code doEvaluate} calls seen by this instance. */
    int getInvocationCount() {
      return evaluations;
    }

    /** Records the call, then evaluates the macro as usual. */
    @Override
    public Object doEvaluate(
      Map<String, Object> argMap,
      Map<String, Object> kwargMap,
      List<Object> varArgs
    ) {
      evaluations += 1;
      return super.doEvaluate(argMap, kwargMap, varArgs);
    }
  }

  /**
   * Renders a template that calls a counting macro once at top level and once
   * inside an {@code if false} body. The regular render evaluates the macro
   * once; after the validating render the counter reads 3, while the visible
   * output in both cases is a single "hello".
   */
  @Test
  public void itDoesNotExecuteMacrosInValidatedBlocks() {
    Context macroScope = interpreter.getContext();
    TextToken helloToken = new TextToken("hello", 1, 1, new DefaultTokenScannerSymbols());
    TextNode helloNode = new TextNode(helloToken);
    InstrumentedMacroFunction macro = new InstrumentedMacroFunction(
      ImmutableList.of(helloNode),
      "hello",
      new LinkedHashMap<>(),
      false,
      macroScope
    );
    macroScope.addGlobalMacro(macro);

    String template =
      "{{ hello() }}" + "{% if false %} " + "  {{ hello() }}" + "{% endif %}";

    // Regular interpreter first.
    assertThat(interpreter.getErrors()).isEmpty();
    String plainOutput = interpreter.render(template);
    assertThat(plainOutput.trim()).isEqualTo("hello");
    assertThat(macro.getInvocationCount()).isEqualTo(1);

    // Then the validation-mode interpreter; the counter is cumulative.
    String validatedOutput = validatingInterpreter.render(template);
    assertThat(validatedOutput.trim()).isEqualTo("hello");
    assertThat(macro.getInvocationCount()).isEqualTo(3);
    assertThat(validatingInterpreter.getErrors()).isEmpty();
  }

  /**
   * The {@code validation_test} function must run only for the top-level call,
   * never for the call inside the {@code if false} body; the syntax error in
   * that body is still reported by the validating interpreter.
   */
  @Test
  public void itDoesNotExecuteFunctionsInValidatedBlocks() {
    // The counter is static, so start from a known value.
    functionExecutionCount = 0;
    assertThat(functionExecutionCount).isEqualTo(0);

    String template =
      "{{ validation_test() }}" +
      "{% if false %}" +
      "  {{ validation_test() }}" +
      "  {{ hey( }}" +
      "{% endif %}";

    // Regular interpreter: one call, no errors.
    String plainOutput = interpreter.render(template);
    assertThat(interpreter.getErrors()).isEmpty();
    assertThat(plainOutput).isEqualTo("1");
    assertThat(functionExecutionCount).isEqualTo(1);

    // Validating interpreter: exactly one more call, plus the reported error.
    String validatedOutput = validatingInterpreter.render(template);
    assertThat(validatingInterpreter.getErrors()).hasSize(1);
    assertThat(validatingInterpreter.getErrors().get(0).getMessage()).contains("hey(");
    assertThat(validatedOutput).isEqualTo("2");
    assertThat(functionExecutionCount).isEqualTo(2);
  }

  /**
   * The {@code validation_filter} filter must run only for the top-level
   * expression, never for the one inside the {@code if false} body; the syntax
   * error in that body is still reported by the validating interpreter.
   */
  @Test
  public void itDoesNotExecuteFiltersInValidatedBlocks() {
    assertThat(validationFilter.getExecutionCount()).isEqualTo(0);

    String template =
      "{{ 10|validation_filter() }}" +
      "{% if false %}" +
      "  {{ 10|validation_filter() }}" +
      "  {{ hey( }}" +
      "{% endif %}";

    // Regular interpreter: the filter runs once and nothing is reported.
    String result = interpreter.render(template).trim();
    assertThat(interpreter.getErrors()).isEmpty();
    assertThat(result).isEqualTo("10");
    assertThat(validationFilter.getExecutionCount()).isEqualTo(1);

    // Make the validating interpreter the current one for the second render.
    // The push is paired with a pop here because tearDown() pops only once
    // (for the push made in setup()); without it an interpreter would remain
    // registered as current after this test finishes.
    JinjavaInterpreter.pushCurrent(validatingInterpreter);
    try {
      result = validatingInterpreter.render(template).trim();

      assertThat(validatingInterpreter.getErrors()).hasSize(1);
      assertThat(validatingInterpreter.getErrors().get(0).getMessage()).contains("hey(");
      assertThat(result).isEqualTo("10");
      assertThat(validationFilter.getExecutionCount()).isEqualTo(2);
    } finally {
      JinjavaInterpreter.popCurrent();
    }
  }
}