/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.sql.test;

import org.apache.calcite.avatica.util.Casing;
import org.apache.calcite.avatica.util.Quoting;
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.config.Lex;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.runtime.Utilities;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCollation;
import org.apache.calcite.sql.SqlIntervalLiteral;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlUnresolvedFunction;
import org.apache.calcite.sql.dialect.AnsiSqlDialect;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.SqlParserUtil;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.SqlConformance;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.sql.validate.SqlMonotonicity;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorNamespace;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.TestUtil;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;

import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.UnaryOperator;

import static org.apache.calcite.sql.SqlUtil.stripAs;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Abstract implementation of {@link SqlTester} that talks to a mock catalog.
 *
 * <p>This is to implement the default behavior: testing is only against the
 * {@link SqlValidator}. The checks in this class parse and validate queries
 * and inspect result types; they do not execute queries or compare results.
 */
public abstract class AbstractSqlTester implements SqlTester, AutoCloseable {
  /** Supplies parsers, validators and named configuration properties. */
  protected final SqlTestFactory factory;
  /** Transform applied to the validator by {@link #checkRewrite}. */
  protected final UnaryOperator<SqlValidator> validatorTransform;

  /**
   * Creates an AbstractSqlTester.
   *
   * @param factory Factory for parsers and validators; must not be null
   * @param validatorTransform Transform applied to the validator in
   *     {@link #checkRewrite}; must not be null
   * @throws NullPointerException if either argument is null
   */
  public AbstractSqlTester(SqlTestFactory factory,
      UnaryOperator<SqlValidator> validatorTransform) {
    this.factory = Objects.requireNonNull(factory, "factory");
    this.validatorTransform =
        Objects.requireNonNull(validatorTransform, "validatorTransform");
  }

  /** Returns the factory this tester was created with. */
  public final SqlTestFactory getFactory() {
    return factory;
  }

  /**
   * {@inheritDoc}
   *
   * <p>This default implementation does nothing.
   */
  @Override public void close() {
    // no resources to release
  }

  /** Returns the "conformance" property of the factory. */
  public final SqlConformance getConformance() {
    return (SqlConformance) factory.get("conformance");
  }

  /** Returns a validator obtained from the factory. */
  public final SqlValidator getValidator() {
    return factory.getValidator();
  }

  /**
   * Parses and validates a query, and checks that any error it raises
   * matches an expected pattern.
   *
   * <p>The query text is first passed through
   * {@link SqlParserUtil#findPos(String)}, which extracts the position
   * markers used when checking the error position.
   *
   * @param sql Query, possibly containing position markers
   * @param expectedMsgPattern Regular expression the error message must
   *     match, or null if no error is expected (as passed by
   *     {@link #checkQuery})
   */
  public void assertExceptionIsThrown(String sql, String expectedMsgPattern) {
    final SqlValidator validator;
    final SqlNode sqlNode;
    final SqlParserUtil.StringAndPos sap = SqlParserUtil.findPos(sql);
    try {
      sqlNode = parseQuery(sap.sql);
      // Creating the validator is inside the try, so a failure here is
      // reported through checkParseEx too.
      validator = getValidator();
    } catch (Throwable e) {
      checkParseEx(e, expectedMsgPattern, sap.sql);
      return;
    }

    Throwable thrown = null;
    try {
      validator.validate(sqlNode);
    } catch (Throwable ex) {
      thrown = ex;
    }

    SqlTests.checkEx(thrown, expectedMsgPattern, sap, SqlTests.Stage.VALIDATE);
  }

  /**
   * Checks an error that occurred before validation, that is, while parsing
   * or while obtaining the validator.
   *
   * <p>Returns normally only if {@code e} is a {@link SqlParseException}
   * whose message matches {@code expectedMsgPattern}. Otherwise it throws a
   * {@link RuntimeException} that wraps {@code e}.
   *
   * @param e Error that was thrown
   * @param expectedMsgPattern Expected message pattern, or null if no error
   *     was expected
   * @param sql Query being parsed (used in error messages)
   */
  protected void checkParseEx(Throwable e, String expectedMsgPattern, String sql) {
    if (!(e instanceof SqlParseException)) {
      // Not a parse error; never an expected outcome here.
      throw new RuntimeException("Error while parsing query: " + sql, e);
    }
    final SqlParseException spe = (SqlParseException) e;
    if (expectedMsgPattern == null) {
      // The query was expected to parse and validate without error.
      throw new RuntimeException("Error while parsing query: " + sql, spe);
    }
    final String errMessage = spe.getMessage();
    if (errMessage == null || !errMessage.matches(expectedMsgPattern)) {
      throw new RuntimeException("Error did not match expected ["
          + expectedMsgPattern + "] while parsing query ["
          + sql + "]", spe);
    }
  }

  /**
   * Returns the type of the single column returned by a query.
   *
   * @param sql Query that must return exactly one column
   * @return Type of that column
   */
  public RelDataType getColumnType(String sql) {
    RelDataType rowType = getResultType(sql);
    final List<RelDataTypeField> fields = rowType.getFieldList();
    assertEquals(1, fields.size(), "expected query to return 1 field");
    return fields.get(0).getType();
  }

  /**
   * Parses and validates a query and returns its validated row type.
   *
   * @param sql Query
   * @return Row type of the validated query
   */
  public RelDataType getResultType(String sql) {
    SqlValidator validator = getValidator();
    SqlNode n = parseAndValidate(validator, sql);

    return validator.getValidatedNodeType(n);
  }

  /**
   * Parses a query and validates it with the given validator.
   *
   * @param validator Validator
   * @param sql Query
   * @return Validated parse tree
   * @throws RuntimeException if the query cannot be parsed; it wraps the
   *     parse error
   */
  public SqlNode parseAndValidate(SqlValidator validator, String sql) {
    SqlNode sqlNode;
    try {
      sqlNode = parseQuery(sql);
    } catch (Throwable e) {
      throw new RuntimeException("Error while parsing query: " + sql, e);
    }
    return validator.validate(sqlNode);
  }

  /**
   * Parses a query using a parser obtained from the factory.
   *
   * @param sql Query
   * @return Unvalidated parse tree
   * @throws SqlParseException if the query is syntactically invalid
   */
  public SqlNode parseQuery(String sql) throws SqlParseException {
    SqlParser parser = factory.createParser(sql);
    return parser.parseQuery();
  }

  /**
   * Checks that a query returns a single column of the expected type.
   *
   * @param sql Query
   * @param expected Expected type, as formatted by
   *     {@code SqlTests.getTypeString}
   */
  public void checkColumnType(String sql, String expected) {
    RelDataType actualType = getColumnType(sql);
    String actual = SqlTests.getTypeString(actualType);
    assertEquals(expected, actual);
  }

  /**
   * Checks the origins of the fields returned by a query.
   *
   * <p>Origins are rendered as {@code {a.b.c, null, ...}}. Each field's
   * origin path is joined with '.', and a field with no origin is shown as
   * {@code null}.
   *
   * @param sql Query
   * @param fieldOriginList Expected rendering of the field origins
   */
  public void checkFieldOrigin(String sql, String fieldOriginList) {
    SqlValidator validator = getValidator();
    SqlNode n = parseAndValidate(validator, sql);
    final List<List<String>> list = validator.getFieldOrigins(n);
    final StringBuilder buf = new StringBuilder("{");
    int i = 0;
    for (List<String> strings : list) {
      if (i++ > 0) {
        buf.append(", ");
      }
      buf.append(strings == null ? "null" : String.join(".", strings));
    }
    buf.append("}");
    assertEquals(fieldOriginList, buf.toString());
  }

  /**
   * Checks the full row type of a query.
   *
   * @param sql Query
   * @param expected Expected row type, as formatted by
   *     {@code SqlTests.getTypeString}
   */
  public void checkResultType(String sql, String expected) {
    RelDataType actualType = getResultType(sql);
    String actual = SqlTests.getTypeString(actualType);
    assertEquals(expected, actual);
  }

  /**
   * Checks how an interval literal is converted.
   *
   * <p>The method finds the first operand of the validated top-level call
   * that, after stripping any {@code AS}, is itself a call. It takes that
   * call's first operand as the interval literal. A year-month interval is
   * converted to months; any other interval is converted to milliseconds.
   *
   * @param sql Query containing an interval literal
   * @param expected Expected converted value, as a decimal string
   */
  public void checkIntervalConv(String sql, String expected) {
    SqlValidator validator = getValidator();
    final SqlCall n = (SqlCall) parseAndValidate(validator, sql);

    SqlNode node = null;
    for (int i = 0; i < n.operandCount(); i++) {
      node = stripAs(n.operand(i));
      if (node instanceof SqlCall) {
        node = ((SqlCall) node).operand(0);
        break;
      }
    }

    assertNotNull(node);
    SqlIntervalLiteral intervalLiteral = (SqlIntervalLiteral) node;
    SqlIntervalLiteral.IntervalValue interval =
        (SqlIntervalLiteral.IntervalValue) intervalLiteral.getValue();
    long l =
        interval.getIntervalQualifier().isYearMonth()
            ? SqlParserUtil.intervalToMonths(interval)
            : SqlParserUtil.intervalToMillis(interval);
    String actual = l + "";
    assertEquals(expected, actual);
  }

  /**
   * Checks the type of a scalar expression in every query form produced
   * by {@link #buildQueries}.
   */
  public void checkType(String expression, String type) {
    for (String sql : buildQueries(expression)) {
      checkColumnType(sql, type);
    }
  }

  /**
   * Checks the collation name and coercibility of a scalar expression.
   *
   * @throws AssertionError if the expression's type has no collation, or if
   *     the collation does not match
   */
  public void checkCollation(
      String expression,
      String expectedCollationName,
      SqlCollation.Coercibility expectedCoercibility) {
    for (String sql : buildQueries(expression)) {
      RelDataType actualType = getColumnType(sql);
      SqlCollation collation = actualType.getCollation();
      // Fail with a readable message, not a NullPointerException.
      assertNotNull(collation, "expected a collation for type " + actualType);

      assertEquals(
          expectedCollationName, collation.getCollationName());
      assertEquals(expectedCoercibility, collation.getCoercibility());
    }
  }

  /**
   * Checks the character set of a scalar expression.
   *
   * @throws AssertionError if the actual charset differs from the expected
   *     one, including when the type has no charset at all
   */
  public void checkCharset(
      String expression,
      Charset expectedCharset) {
    for (String sql : buildQueries(expression)) {
      RelDataType actualType = getColumnType(sql);
      Charset actualCharset = actualType.getCharset();

      if (!expectedCharset.equals(actualCharset)) {
        // actualCharset may be null (e.g. non-character type); report it
        // rather than throwing NullPointerException.
        fail("\n"
            + "Expected=" + expectedCharset.name() + "\n"
            + "  actual="
            + (actualCharset == null ? null : actualCharset.name()));
      }
    }
  }

  public SqlTester withQuoting(Quoting quoting) {
    return with("quoting", quoting);
  }

  public SqlTester withQuotedCasing(Casing casing) {
    return with("quotedCasing", casing);
  }

  public SqlTester withUnquotedCasing(Casing casing) {
    return with("unquotedCasing", casing);
  }

  public SqlTester withCaseSensitive(boolean sensitive) {
    return with("caseSensitive", sensitive);
  }

  public SqlTester withLenientOperatorLookup(boolean lenient) {
    return with("lenientOperatorLookup", lenient);
  }

  /** Applies the quoting, case-sensitivity and casing settings of a lexical policy. */
  public SqlTester withLex(Lex lex) {
    return withQuoting(lex.quoting)
        .withCaseSensitive(lex.caseSensitive)
        .withQuotedCasing(lex.quotedCasing)
        .withUnquotedCasing(lex.unquotedCasing);
  }

  /**
   * Returns a tester with the given conformance.
   *
   * <p>A null conformance means {@link SqlConformanceEnum#DEFAULT}. If the
   * conformance is a {@link SqlConformanceEnum}, the connection factory is
   * also set to {@code CalciteAssert.EMPTY_CONNECTION_FACTORY} with the
   * same conformance.
   */
  public SqlTester withConformance(SqlConformance conformance) {
    if (conformance == null) {
      conformance = SqlConformanceEnum.DEFAULT;
    }
    final SqlTester tester = with("conformance", conformance);
    if (conformance instanceof SqlConformanceEnum) {
      return tester
          .withConnectionFactory(
              CalciteAssert.EMPTY_CONNECTION_FACTORY
                  .with(CalciteConnectionProperty.CONFORMANCE, conformance));
    } else {
      return tester;
    }
  }

  public SqlTester enableTypeCoercion(boolean enabled) {
    return with("enableTypeCoercion", enabled);
  }

  public SqlTester withOperatorTable(SqlOperatorTable operatorTable) {
    return with("operatorTable", operatorTable);
  }

  public SqlTester withConnectionFactory(
      CalciteAssert.ConnectionFactory connectionFactory) {
    return with("connectionFactory", connectionFactory);
  }

  /** Returns a tester whose factory has the named property set to a value. */
  protected final SqlTester with(final String name, final Object value) {
    return with(factory.with(name, value));
  }

  /** Returns a tester of the same kind as this one that uses a given factory. */
  protected abstract SqlTester with(SqlTestFactory factory);

  // SqlTester methods

  public void setFor(
      SqlOperator operator,
      VmName... unimplementedVmNames) {
    // do nothing
  }

  public void checkAgg(
      String expr,
      String[] inputValues,
      Object result,
      double delta) {
    String query =
        SqlTests.generateAggQuery(expr, inputValues);
    check(query, SqlTests.ANY_TYPE_CHECKER, result, delta);
  }

  public void checkAggWithMultipleArgs(
      String expr,
      String[][] inputValues,
      Object result,
      double delta) {
    String query =
        SqlTests.generateAggQueryWithMultipleArgs(expr, inputValues);
    check(query, SqlTests.ANY_TYPE_CHECKER, result, delta);
  }

  /**
   * Checks a windowed aggregate.
   *
   * <p>Note: the {@code type} argument is not used by this implementation.
   */
  public void checkWinAgg(
      String expr,
      String[] inputValues,
      String windowSpec,
      String type,
      Object result,
      double delta) {
    String query =
        SqlTests.generateWinAggQuery(
            expr, windowSpec, inputValues);
    check(query, SqlTests.ANY_TYPE_CHECKER, result, delta);
  }

  public void checkScalar(
      String expression,
      Object result,
      String resultType) {
    checkType(expression, resultType);
    for (String sql : buildQueries(expression)) {
      check(sql, SqlTests.ANY_TYPE_CHECKER, result, 0);
    }
  }

  public void checkScalarExact(
      String expression,
      String result) {
    for (String sql : buildQueries(expression)) {
      check(sql, SqlTests.INTEGER_TYPE_CHECKER, result, 0);
    }
  }

  public void checkScalarExact(
      String expression,
      String expectedType,
      String result) {
    for (String sql : buildQueries(expression)) {
      TypeChecker typeChecker =
          new SqlTests.StringTypeChecker(expectedType);
      check(sql, typeChecker, result, 0);
    }
  }

  public void checkScalarApprox(
      String expression,
      String expectedType,
      double expectedResult,
      double delta) {
    for (String sql : buildQueries(expression)) {
      TypeChecker typeChecker =
          new SqlTests.StringTypeChecker(expectedType);
      check(sql, typeChecker, expectedResult, delta);
    }
  }

  /**
   * Checks a boolean expression.
   *
   * @param expression Scalar expression
   * @param result Expected result, or null if the expression should
   *     evaluate to null
   */
  public void checkBoolean(
      String expression,
      Boolean result) {
    if (null == result) {
      // checkNull already iterates over every query form; call it once,
      // not once per query form.
      checkNull(expression);
      return;
    }
    for (String sql : buildQueries(expression)) {
      check(
          sql,
          SqlTests.BOOLEAN_TYPE_CHECKER,
          result.toString(),
          0);
    }
  }

  public void checkString(
      String expression,
      String result,
      String expectedType) {
    for (String sql : buildQueries(expression)) {
      TypeChecker typeChecker =
          new SqlTests.StringTypeChecker(expectedType);
      check(sql, typeChecker, result, 0);
    }
  }

  public void checkNull(String expression) {
    for (String sql : buildQueries(expression)) {
      check(sql, SqlTests.ANY_TYPE_CHECKER, null, 0);
    }
  }

  /**
   * Checks a query with any parameter types and a result checker built
   * from {@code result} and {@code delta}.
   */
  public final void check(
      String query,
      TypeChecker typeChecker,
      Object result,
      double delta) {
    check(query, typeChecker, SqlTests.ANY_PARAMETER_CHECKER,
        SqlTests.createChecker(result, delta));
  }

  /**
   * Checks a query's result type and parameter types.
   *
   * <p>This implementation does NOT check the result; {@code resultChecker}
   * is ignored. All it does is check the return type and parameter row type.
   */
  public void check(String query, TypeChecker typeChecker,
      ParameterChecker parameterChecker, ResultChecker resultChecker) {
    if (typeChecker == null) {
      // Parse and validate. There should be no errors.
      Util.discard(getResultType(query));
    } else {
      // Parse and validate. There should be no errors.
      // There must be 1 column. Get its type.
      RelDataType actualType = getColumnType(query);

      // Check result type.
      typeChecker.checkType(actualType);
    }

    // The query is parsed and validated again to get a validator that holds
    // the validated tree, from which the parameter types are read.
    SqlValidator validator = getValidator();
    SqlNode n = parseAndValidate(validator, query);
    final RelDataType parameterRowType = validator.getParameterRowType(n);
    parameterChecker.checkParameters(parameterRowType);
  }

  /** Checks the monotonicity of the first column of a query. */
  public void checkMonotonic(String query,
      SqlMonotonicity expectedMonotonicity) {
    SqlValidator validator = getValidator();
    SqlNode n = parseAndValidate(validator, query);
    final RelDataType rowType = validator.getValidatedNodeType(n);
    final SqlValidatorNamespace selectNamespace = validator.getNamespace(n);
    final String field0 = rowType.getFieldList().get(0).getName();
    final SqlMonotonicity monotonicity =
        selectNamespace.getMonotonicity(field0);
    assertThat(monotonicity, equalTo(expectedMonotonicity));
  }

  /**
   * Checks the SQL that the validator produces for a query, using the
   * validator after {@link #validatorTransform} is applied. The SQL is
   * rendered in the ANSI dialect.
   */
  public void checkRewrite(String query, String expectedRewrite) {
    final SqlValidator validator = validatorTransform.apply(getValidator());
    SqlNode rewrittenNode = parseAndValidate(validator, query);
    String actualRewrite =
        rewrittenNode.toSqlString(AnsiSqlDialect.DEFAULT, false).getSql();
    TestUtil.assertEqualsVerbose(expectedRewrite, Util.toLinux(actualRewrite));
  }

  /**
   * Checks that a scalar expression fails.
   *
   * <p>If {@code runtime} is true, this implementation only checks that the
   * expression validates successfully; {@code expectedError} is not used.
   * Otherwise it checks that validation fails with {@code expectedError}.
   */
  public void checkFails(
      String expression,
      String expectedError,
      boolean runtime) {
    if (runtime) {
      // We need to test that the expression fails at runtime.
      // Ironically, that means that it must succeed at prepare time.
      SqlValidator validator = getValidator();
      final String sql = buildQuery(expression);
      SqlNode n = parseAndValidate(validator, sql);
      assertNotNull(n);
    } else {
      checkQueryFails(buildQuery(expression), expectedError);
    }
  }

  public void checkQueryFails(String sql, String expectedError) {
    assertExceptionIsThrown(sql, expectedError);
  }

  /** Checks that a query parses and validates without error. */
  public void checkQuery(String sql) {
    assertExceptionIsThrown(sql, null);
  }

  /**
   * Returns the monotonicity of the first select item of a query, which
   * must be a {@code SELECT}.
   */
  public SqlMonotonicity getMonotonicity(String sql) {
    final SqlValidator validator = getValidator();
    final SqlNode node = parseAndValidate(validator, sql);
    final SqlSelect select = (SqlSelect) node;
    final SqlNode selectItem0 = select.getSelectList().get(0);
    final SqlValidatorScope scope = validator.getSelectScope(select);
    return selectItem0.getMonotonicity(scope);
  }

  /** Wraps a scalar expression in a {@code VALUES} query. */
  public static String buildQuery(String expression) {
    return "values (" + expression + ")";
  }

  /** Wraps an aggregate expression in a single-row, grouped query. */
  public static String buildQueryAgg(String expression) {
    return "select " + expression + " from (values (1)) as t(x) group by x";
  }

  /**
   * Builds a query that extracts all literals as columns in an underlying
   * select.
   *
   * <p>For example,</p>
   *
   * <blockquote>{@code 1 < 5}</blockquote>
   *
   * <p>becomes</p>
   *
   * <blockquote>{@code select p1 < p0
   * from (values (5, 1)) as t(p0, p1)}</blockquote>
   *
   * <p>Literals are numbered from the end of the expression toward the
   * start, so the last literal becomes {@code p0}. If the expression has no
   * extractable literals, the underlying select is {@code (values (1))}
   * with column {@code p0}.</p>
   *
   * <p>Null literals don't have enough type information to be extracted.
   * We push down {@code CAST(NULL AS type)} but raw nulls such as
   * {@code CASE 1 WHEN 2 THEN 'a' ELSE NULL END} are left as is. Symbol
   * literals, and the arguments of LITERAL_CHAIN, LOCALTIME, LOCALTIMESTAMP,
   * CURRENT_TIME and CURRENT_TIMESTAMP, are also left in place.</p>
   *
   * @param expression Scalar expression
   * @return Query that evaluates a scalar expression
   */
  protected String buildQuery2(String expression) {
    // "values (1 < 5)"
    // becomes
    // "select p1 < p0 from (values (5, 1)) as t(p0, p1)"
    SqlNode x;
    final String sql = "values (" + expression + ")";
    try {
      x = parseQuery(sql);
    } catch (SqlParseException e) {
      throw TestUtil.rethrow(e);
    }
    final Collection<SqlNode> literalSet = new LinkedHashSet<>();
    x.accept(
        new SqlShuttle() {
          private final List<SqlOperator> ops =
              ImmutableList.of(
                  SqlStdOperatorTable.LITERAL_CHAIN,
                  SqlStdOperatorTable.LOCALTIME,
                  SqlStdOperatorTable.LOCALTIMESTAMP,
                  SqlStdOperatorTable.CURRENT_TIME,
                  SqlStdOperatorTable.CURRENT_TIMESTAMP);

          @Override public SqlNode visit(SqlLiteral literal) {
            if (!isNull(literal)
                && literal.getTypeName() != SqlTypeName.SYMBOL) {
              literalSet.add(literal);
            }
            return literal;
          }

          @Override public SqlNode visit(SqlCall call) {
            SqlOperator operator = call.getOperator();
            if (operator instanceof SqlUnresolvedFunction) {
              // Resolve the function against the standard operator table so
              // that the CAST / ops checks below see the real operator.
              final SqlUnresolvedFunction unresolvedFunction = (SqlUnresolvedFunction) operator;
              final SqlOperator lookup = SqlValidatorUtil.lookupSqlFunctionByID(
                  SqlStdOperatorTable.instance(),
                  unresolvedFunction.getSqlIdentifier(),
                  unresolvedFunction.getFunctionType());
              if (lookup != null) {
                operator = lookup;
                final SqlNode[] operands = call.getOperandList().toArray(SqlNode.EMPTY_ARRAY);
                call = operator.createCall(
                    call.getFunctionQuantifier(),
                    call.getParserPosition(),
                    operands);
              }
            }
            if (operator == SqlStdOperatorTable.CAST
                && isNull(call.operand(0))) {
              // CAST(NULL AS type) carries a type, so extract it whole.
              literalSet.add(call);
              return call;
            } else if (ops.contains(operator)) {
              // "Argument to function 'LOCALTIME' must be a
              // literal"
              return call;
            } else {
              return super.visit(call);
            }
          }

          private boolean isNull(SqlNode sqlNode) {
            return sqlNode instanceof SqlLiteral
                && ((SqlLiteral) sqlNode).getTypeName()
                == SqlTypeName.NULL;
          }
        });
    final List<SqlNode> nodes = new ArrayList<>(literalSet);
    // Sort by descending source position. Replacing later literals first
    // means offsets computed against the original text stay valid for the
    // literals not yet replaced.
    nodes.sort((o1, o2) -> {
      final SqlParserPos pos0 = o1.getParserPosition();
      final SqlParserPos pos1 = o2.getParserPosition();
      int c = -Utilities.compare(pos0.getLineNum(), pos1.getLineNum());
      if (c != 0) {
        return c;
      }
      return -Utilities.compare(pos0.getColumnNum(), pos1.getColumnNum());
    });
    String sql2 = sql;
    final List<Pair<String, String>> values = new ArrayList<>();
    int p = 0;
    for (SqlNode literal : nodes) {
      final SqlParserPos pos = literal.getParserPosition();
      final int start =
          SqlParserUtil.lineColToIndex(
              sql, pos.getLineNum(), pos.getColumnNum());
      // End column is inclusive, hence + 1 for an exclusive substring bound.
      final int end =
          SqlParserUtil.lineColToIndex(
              sql,
              pos.getEndLineNum(),
              pos.getEndColumnNum()) + 1;
      String param = "p" + (p++);
      values.add(Pair.of(sql2.substring(start, end), param));
      sql2 = sql2.substring(0, start)
          + param
          + sql2.substring(end);
    }
    if (values.isEmpty()) {
      values.add(Pair.of("1", "p0"));
    }
    // Strip the leading "values (" and trailing ")" to get the expression.
    return "select "
        + sql2.substring("values (".length(), sql2.length() - 1)
        + " from (values ("
        + Util.commaList(Pair.left(values))
        + ")) as t("
        + Util.commaList(Pair.right(values))
        + ")";
  }

  /**
   * Converts a scalar expression into a list of SQL queries that
   * evaluate it: first {@link #buildQuery}, then {@link #buildQuery2}.
   *
   * @param expression Scalar expression
   * @return List of queries that evaluate an expression
   */
  private Iterable<String> buildQueries(final String expression) {
    // Why an explicit iterable rather than a list? If there is
    // a syntax error in the expression, the calling code discovers it
    // before we try to parse it to do substitutions on the parse tree.
    return () -> new Iterator<String>() {
      int i = 0;

      @Override public void remove() {
        throw new UnsupportedOperationException();
      }

      @Override public String next() {
        switch (i++) {
        case 0:
          return buildQuery(expression);
        case 1:
          return buildQuery2(expression);
        default:
          throw new NoSuchElementException();
        }
      }

      @Override public boolean hasNext() {
        return i < 2;
      }
    };
  }

}