* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*   http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.

package org.apache.samza.sql.planner;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.samza.SamzaException;
import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class Checker implements SqlOperandTypeChecker {

  private static final Logger LOG = LoggerFactory.getLogger(Checker.class);
  private static final List<SqlTypeName> ANY_SQL_TYPE_NAMES = ImmutableList.of(SqlTypeName.ANY, SqlTypeName.OTHER);
  static final Checker ANY_CHECKER = new Checker();

  private final Optional<UdfMetadata> udfMetadataOptional;
  private final SqlOperandCountRange range;

  public static Checker getChecker(int min, int max, UdfMetadata udfMetadata) {
    if (min == max) {
      return new Checker(min, udfMetadata);
    } else {
      return new Checker(min, max, udfMetadata);

  private Checker(int size, UdfMetadata udfMetadata) {
    this.range = SqlOperandCountRanges.of(size);
    this.udfMetadataOptional = Optional.of(udfMetadata);

  private Checker(int min, int max, UdfMetadata udfMetadata) {
    this.range = SqlOperandCountRanges.between(min, max);
    this.udfMetadataOptional = Optional.of(udfMetadata);

  private Checker() {
    this.range = SqlOperandCountRanges.any();
    this.udfMetadataOptional = Optional.empty();

  public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
    if (!udfMetadataOptional.isPresent() || udfMetadataOptional.get().isDisableArgCheck() || !throwOnFailure) {
      return true;
    } else {
      // 1. Generate a mapping from argument index to parsed calcite-type for the sql UDF.
      Map<Integer, RelDataType> argumentIndexToCalciteType = IntStream.range(0, callBinding.getOperandCount())
          .collect(Collectors.toMap(operandIndex -> operandIndex, callBinding::getOperandType, (a, b) -> b));

      UdfMetadata udfMetadata = udfMetadataOptional.get();
      List<SamzaSqlFieldType> udfArguments = udfMetadata.getArguments();

      // 2. Compare the argument type in samza-sql UDF against the RelType generated by the
      // calcite parser engine.
      for (int udfArgumentIndex = 0; udfArgumentIndex < udfArguments.size(); ++udfArgumentIndex) {
        SamzaSqlFieldType udfArgumentType = udfArguments.get(udfArgumentIndex);
        SqlTypeName udfArgumentAsSqlType = toCalciteSqlType(udfArgumentType);
        RelDataType parsedSqlArgType = argumentIndexToCalciteType.get(udfArgumentIndex);

        // 3(a). Special-case, where static strings used as method-arguments in udf-methods during invocation are parsed as the Char type by calcite.
        if (parsedSqlArgType.getSqlTypeName() == SqlTypeName.CHAR && udfArgumentAsSqlType == SqlTypeName.VARCHAR) {
          return true;
        } else if (!Objects.equals(parsedSqlArgType.getSqlTypeName(), udfArgumentAsSqlType)
                && !ANY_SQL_TYPE_NAMES.contains(parsedSqlArgType.getSqlTypeName()) && hasOneUdfMethod(udfMetadata)) {
          // 3(b). Throw up and fail on mismatch between the SamzaSqlType and CalciteType for any argument.
          String msg = String.format("Type mismatch in udf class: %s at argument index: %d." +
                          "Expected type: %s, actual type: %s.", udfMetadata.getName(),
                  udfArgumentIndex, parsedSqlArgType.getSqlTypeName(), udfArgumentAsSqlType);
          throw new SamzaSqlValidatorException(msg);
    // 4. The SamzaSqlFieldType and CalciteType has matched for all the arguments in the UDF.
    return true;

   * Checks if there is only one UdfMethod in the input {@link UdfMetadata}.
   * @param udfMetadata the metadata for a UDF.
   * @return true if there is only one udf method defined in the UdfMetadata.
   *         false otherwise.
  boolean hasOneUdfMethod(UdfMetadata udfMetadata) {
    Class<?> udfClass = udfMetadata.getUdfMethod().getDeclaringClass();
    int numAnnotatedUdfMethods = 0;
    for (Method method : udfClass.getMethods()) {
      if (method.isAnnotationPresent(SamzaSqlUdfMethod.class)) {
        numAnnotatedUdfMethods += 1;
    return numAnnotatedUdfMethods == 1;

  public SqlOperandCountRange getOperandCountRange() {
    return range;

  public String getAllowedSignatures(SqlOperator op, String opName) {
    return opName + "(Drill - Opaque)";

  public Consistency getConsistency() {
    return Consistency.NONE;

  public boolean isOptional(int i) {
    return false;

   * Converts the {@link SamzaSqlFieldType} to the calcite {@link SqlTypeName}.
   * @param samzaSqlFieldType the samza sql field type.
   * @return the converted calcite SqlTypeName.
  static SqlTypeName toCalciteSqlType(SamzaSqlFieldType samzaSqlFieldType) {
    switch (samzaSqlFieldType) {
      case ANY:
      case ROW:
        return SqlTypeName.ANY;
      case MAP:
        return SqlTypeName.MAP;
      case ARRAY:
        return SqlTypeName.ARRAY;
      case REAL:
        return SqlTypeName.REAL;
      case DOUBLE:
        return SqlTypeName.DOUBLE;
      case STRING:
        return SqlTypeName.VARCHAR;
      case INT16:
      case INT32:
        return SqlTypeName.INTEGER;
      case FLOAT:
        return SqlTypeName.FLOAT;
      case INT64:
        return SqlTypeName.BIGINT;
      case BOOLEAN:
        return SqlTypeName.BOOLEAN;
      case BYTES:
        return SqlTypeName.VARBINARY;
        String msg = String.format("Field Type %s is not supported", samzaSqlFieldType);
        throw new SamzaException(msg);