/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.samza.sql.planner;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.samza.SamzaException;
import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A calcite {@link SqlOperandTypeChecker} that compares the argument types declared in a
 * samza-sql {@link UdfMetadata} against the operand types calcite derived for a call.
 *
 * <p>An instance holds an operand-count range and, optionally, the metadata of the UDF it checks.
 * {@link #ANY_CHECKER} holds no metadata and therefore accepts every call.
 */
class Checker implements SqlOperandTypeChecker {

  private static final Logger LOG = LoggerFactory.getLogger(Checker.class);

  /** Parsed calcite types that are never reported as a mismatch against a declared UDF type. */
  private static final List<SqlTypeName> ANY_SQL_TYPE_NAMES = ImmutableList.of(SqlTypeName.ANY, SqlTypeName.OTHER);

  /** Checker with an unrestricted operand count and no UDF metadata; its type check always passes. */
  static final Checker ANY_CHECKER = new Checker();

  // Empty only for ANY_CHECKER; every other constructor stores the metadata it is given.
  private final Optional<UdfMetadata> udfMetadataOptional;
  private final SqlOperandCountRange range;

  /**
   * Creates a checker for the given operand-count bounds.
   *
   * @param min the minimum number of operands.
   * @param max the maximum number of operands; when equal to {@code min} an exact-count range is used.
   * @param udfMetadata the metadata of the UDF whose arguments are checked; must not be null.
   * @return a checker bound to the operand-count range and the udf metadata.
   */
  public static Checker getChecker(int min, int max, UdfMetadata udfMetadata) {
    if (min == max) {
      return new Checker(min, udfMetadata);
    } else {
      return new Checker(min, max, udfMetadata);
    }
  }

  private Checker(int size, UdfMetadata udfMetadata) {
    this.range = SqlOperandCountRanges.of(size);
    this.udfMetadataOptional = Optional.of(udfMetadata);
  }

  private Checker(int min, int max, UdfMetadata udfMetadata) {
    this.range = SqlOperandCountRanges.between(min, max);
    this.udfMetadataOptional = Optional.of(udfMetadata);
  }

  private Checker() {
    this.range = SqlOperandCountRanges.any();
    this.udfMetadataOptional = Optional.empty();
  }

  /**
   * Validates the operand types of a UDF call against the argument types declared in the
   * {@link UdfMetadata}.
   *
   * <p>The check is skipped (and {@code true} returned) when this checker has no metadata, when
   * the metadata disables argument checking, or when {@code throwOnFailure} is false. Otherwise
   * every declared argument is compared with the operand at the same index. A mismatch is reported
   * only when the parsed type is not one of {@link #ANY_SQL_TYPE_NAMES} and the UDF class declares
   * exactly one annotated udf method.
   *
   * @param callBinding the call whose operand types are inspected.
   * @param throwOnFailure whether a failed check may throw; when false no check is performed.
   * @return always true; a failed check is signalled by an exception instead.
   * @throws SamzaSqlValidatorException if a declared argument type does not match the parsed type.
   */
  @Override
  public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
    if (!udfMetadataOptional.isPresent() || udfMetadataOptional.get().isDisableArgCheck() || !throwOnFailure) {
      return true;
    }

    // 1. Generate a mapping from argument index to parsed calcite-type for the sql UDF.
    Map<Integer, RelDataType> argumentIndexToCalciteType = IntStream.range(0, callBinding.getOperandCount())
        .boxed()
        .collect(Collectors.toMap(operandIndex -> operandIndex, callBinding::getOperandType, (a, b) -> b));

    UdfMetadata udfMetadata = udfMetadataOptional.get();
    List<SamzaSqlFieldType> udfArguments = udfMetadata.getArguments();

    // 2. Compare the argument type in samza-sql UDF against the RelType generated by the
    // calcite parser engine.
    for (int udfArgumentIndex = 0; udfArgumentIndex < udfArguments.size(); ++udfArgumentIndex) {
      SamzaSqlFieldType udfArgumentType = udfArguments.get(udfArgumentIndex);
      SqlTypeName udfArgumentAsSqlType = toCalciteSqlType(udfArgumentType);
      RelDataType parsedSqlArgType = argumentIndexToCalciteType.get(udfArgumentIndex);

      if (parsedSqlArgType == null) {
        // The call supplies fewer operands than the UDF declares, so there is nothing to compare
        // at this index. Operand-count validation is presumably done by calcite through
        // getOperandCountRange() - TODO confirm.
        continue;
      }

      SqlTypeName parsedSqlTypeName = parsedSqlArgType.getSqlTypeName();

      // 3(a). Special-case, where static strings used as method-arguments in udf-methods during
      // invocation are parsed as the Char type by calcite. Accept this argument and keep
      // validating the remaining ones instead of ending the whole check here.
      if (parsedSqlTypeName == SqlTypeName.CHAR && udfArgumentAsSqlType == SqlTypeName.VARCHAR) {
        continue;
      }

      if (!Objects.equals(parsedSqlTypeName, udfArgumentAsSqlType)
          && !ANY_SQL_TYPE_NAMES.contains(parsedSqlTypeName)
          && hasOneUdfMethod(udfMetadata)) {
        // 3(b). Throw up and fail on mismatch between the SamzaSqlType and CalciteType for any argument.
        // NOTE(review): the message labels the parsed query type as "expected" and the declared
        // udf type as "actual"; text kept unchanged here - confirm the labels are intended.
        String msg = String.format("Type mismatch in udf class: %s at argument index: %d."
            + "Expected type: %s, actual type: %s.", udfMetadata.getName(), udfArgumentIndex,
            parsedSqlTypeName, udfArgumentAsSqlType);
        LOG.error(msg);
        throw new SamzaSqlValidatorException(msg);
      }
    }

    // 4. The SamzaSqlFieldType and CalciteType has matched for all the arguments in the UDF.
    return true;
  }

  /**
   * Checks if there is only one UdfMethod in the input {@link UdfMetadata}.
   *
   * <p>The count covers the public methods of the class declaring the udf method that carry the
   * {@link SamzaSqlUdfMethod} annotation.
   *
   * @param udfMetadata the metadata for a UDF.
   * @return true if there is only one udf method defined in the UdfMetadata.
   *         false otherwise.
   */
  @VisibleForTesting
  boolean hasOneUdfMethod(UdfMetadata udfMetadata) {
    Class<?> udfClass = udfMetadata.getUdfMethod().getDeclaringClass();
    int numAnnotatedUdfMethods = 0;
    for (Method method : udfClass.getMethods()) {
      if (method.isAnnotationPresent(SamzaSqlUdfMethod.class)) {
        numAnnotatedUdfMethods += 1;
      }
    }
    return numAnnotatedUdfMethods == 1;
  }

  /**
   * Returns the operand-count range this checker was constructed with.
   */
  @Override
  public SqlOperandCountRange getOperandCountRange() {
    return range;
  }

  /**
   * Returns the operator name followed by a fixed opaque-signature suffix; {@code op} is not used.
   */
  @Override
  public String getAllowedSignatures(SqlOperator op, String opName) {
    return opName + "(Drill - Opaque)";
  }

  /**
   * Always returns {@link Consistency#NONE}.
   */
  @Override
  public Consistency getConsistency() {
    return Consistency.NONE;
  }

  /**
   * Always returns false: no operand index is reported as optional.
   */
  @Override
  public boolean isOptional(int i) {
    return false;
  }

  /**
   * Converts the {@link SamzaSqlFieldType} to the calcite {@link SqlTypeName}.
   *
   * @param samzaSqlFieldType the samza sql field type.
   * @return the converted calcite SqlTypeName.
   * @throws SamzaException if the field type has no mapping.
   */
  @VisibleForTesting
  static SqlTypeName toCalciteSqlType(SamzaSqlFieldType samzaSqlFieldType) {
    switch (samzaSqlFieldType) {
      case ANY:
      case ROW:
        return SqlTypeName.ANY;
      case MAP:
        return SqlTypeName.MAP;
      case ARRAY:
        return SqlTypeName.ARRAY;
      case REAL:
        return SqlTypeName.REAL;
      case DOUBLE:
        return SqlTypeName.DOUBLE;
      case STRING:
        return SqlTypeName.VARCHAR;
      case INT16:
      case INT32:
        return SqlTypeName.INTEGER;
      case FLOAT:
        return SqlTypeName.FLOAT;
      case INT64:
        return SqlTypeName.BIGINT;
      case BOOLEAN:
        return SqlTypeName.BOOLEAN;
      case BYTES:
        return SqlTypeName.VARBINARY;
      default:
        String msg = String.format("Field Type %s is not supported", samzaSqlFieldType);
        LOG.error(msg);
        throw new SamzaException(msg);
    }
  }
}