/******************************************************************************* * Copyright (c) 2013, Salesforce.com, Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * Neither the name of Salesforce.com nor the names of its contributors may * be used to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
******************************************************************************/
package com.salesforce.phoenix.parse;

import java.lang.reflect.Constructor;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp;
import org.apache.hadoop.hbase.util.Pair;

import com.google.common.collect.ListMultimap;
import com.google.common.collect.Maps;
import com.salesforce.phoenix.exception.UnknownFunctionException;
import com.salesforce.phoenix.expression.Expression;
import com.salesforce.phoenix.expression.ExpressionType;
import com.salesforce.phoenix.expression.function.AvgAggregateFunction;
import com.salesforce.phoenix.expression.function.CountAggregateFunction;
import com.salesforce.phoenix.expression.function.CurrentDateFunction;
import com.salesforce.phoenix.expression.function.CurrentTimeFunction;
import com.salesforce.phoenix.expression.function.DistinctCountAggregateFunction;
import com.salesforce.phoenix.expression.function.FunctionExpression;
import com.salesforce.phoenix.parse.FunctionParseNode.BuiltInFunction;
import com.salesforce.phoenix.parse.FunctionParseNode.BuiltInFunctionInfo;
import com.salesforce.phoenix.parse.JoinTableNode.JoinType;
import com.salesforce.phoenix.schema.ColumnModifier;
import com.salesforce.phoenix.schema.PDataType;
import com.salesforce.phoenix.schema.PIndexState;
import com.salesforce.phoenix.schema.PTableType;
import com.salesforce.phoenix.schema.TypeMismatchException;
import com.salesforce.phoenix.util.SchemaUtil;

/**
 *
 * Factory used by parser to construct object model while parsing a SQL statement
 *
 * @author jtaylor
 * @since 0.1
 */
public class ParseNodeFactory {
    /** Name of the built-in function that {@link #arrayElemRef(List)} resolves to. */
    private static final String ARRAY_ELEM = "ARRAY_ELEM";

    // TODO: Use Google's Reflection library instead to find aggregate functions
    /**
     * Function classes registered explicitly, in addition to the {@link FunctionExpression}
     * classes discovered by scanning {@link ExpressionType}.
     */
    @SuppressWarnings("unchecked")
    private static final List<Class<? extends FunctionExpression>> CLIENT_SIDE_BUILT_IN_FUNCTIONS = Arrays.<Class<? extends FunctionExpression>>asList(
        CurrentDateFunction.class,
        CurrentTimeFunction.class,
        AvgAggregateFunction.class
        );

    /**
     * Registry of built-in functions keyed by (normalized name, argument count).
     * Populated lazily by {@link #initBuiltInFunctionMap()}, which is the only writer and
     * runs while holding the class monitor.
     */
    private static final Map<BuiltInFunctionKey, BuiltInFunctionInfo> BUILT_IN_FUNCTION_MAP = Maps.newHashMap();

    /**
     *
     * Key used to look up a built-in function using the combination of
     * the normalized name (see {@link SchemaUtil#normalizeIdentifier(String)}) and
     * the number of arguments. This disambiguates the aggregate {@code MAX(<col>)}
     * from the non aggregate {@code MAX(<col1>,<col2>)}.
     *
     * @author jtaylor
     * @since 0.1
     */
    private static class BuiltInFunctionKey {
        private final String normalizedName;
        private final int argCount;

        private BuiltInFunctionKey(String normalizedName, int argCount) {
            this.normalizedName = normalizedName;
            this.argCount = argCount;
        }

        @Override
        public String toString() {
            return normalizedName;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + argCount;
            result = prime * result + ((normalizedName == null) ? 0 : normalizedName.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) return true;
            if (obj == null) return false;
            if (getClass() != obj.getClass()) return false;
            BuiltInFunctionKey other = (BuiltInFunctionKey)obj;
            if (argCount != other.argCount) return false;
            // Null-safe so equals() agrees with hashCode(), which tolerates a null name
            return normalizedName == null
                    ? other.normalizedName == null
                    : normalizedName.equals(other.normalizedName);
        }
    }

    /**
     * Registers a function class in {@link #BUILT_IN_FUNCTION_MAP} if it carries a
     * {@link BuiltInFunction} annotation; classes without the annotation are ignored.
     * One entry is added for the full argument count and one for each shorter count
     * obtained by omitting trailing arguments that declare a default value.
     *
     * @param f the function class to register
     * @throws IllegalStateException if a non-trailing argument declares a default value, or
     *         if another function is already registered under the same name and argument count
     * @throws Exception propagated from building the {@link BuiltInFunctionInfo}
     */
    private static void addBuiltInFunction(Class<? extends FunctionExpression> f) throws Exception {
        BuiltInFunction d = f.getAnnotation(BuiltInFunction.class);
        if (d == null) {
            return;
        }
        int nArgs = d.args().length;
        BuiltInFunctionInfo value = new BuiltInFunctionInfo(f, d);

        // minArgs is the start of the run of trailing arguments that all have default values
        int minArgs = nArgs;
        while (minArgs > 0 && d.args()[minArgs - 1].defaultValue().length() > 0) {
            minArgs--;
        }

        // Only trailing arguments may have default values. Scan from the end, like the
        // registration order, and validate before registering anything.
        for (int i = minArgs - 1; i >= 0; i--) {
            if (d.args()[i].defaultValue().length() > 0) {
                throw new IllegalStateException("Function " + value.getName() + " has non trailing default value of '" + d.args()[i].defaultValue() + "'. Only trailing arguments may have default values");
            }
        }

        // Add an entry for each possible version of the function, throwing if conflicts are found
        for (int argCount = nArgs; argCount >= minArgs; argCount--) {
            BuiltInFunctionKey key = new BuiltInFunctionKey(value.getName(), argCount);
            if (BUILT_IN_FUNCTION_MAP.put(key, value) != null) {
                throw new IllegalStateException("Multiple " + value.getName() + " functions with " + argCount + " arguments");
            }
        }
    }

    /**
     * Reflect this class and populate static structures from it.
     * Don't initialize in static block because we have a circular dependency.
     * A non-empty map is treated as "already initialized", so on any failure the map is
     * cleared again. Otherwise a partially populated map would be used by later lookups.
     *
     * @throws RuntimeException wrapping the first failure, naming the class being registered
     */
    private synchronized static void initBuiltInFunctionMap() {
        if (!BUILT_IN_FUNCTION_MAP.isEmpty()) {
            return;
        }
        Class<? extends FunctionExpression> f = null;
        boolean initialized = false;
        try {
            // Reflection based parsing which yields direct explicit function evaluation at runtime
            for (Class<? extends FunctionExpression> clientSideFunction : CLIENT_SIDE_BUILT_IN_FUNCTIONS) {
                f = clientSideFunction;
                addBuiltInFunction(f);
            }
            for (ExpressionType et : ExpressionType.values()) {
                Class<? extends Expression> ec = et.getExpressionClass();
                if (FunctionExpression.class.isAssignableFrom(ec)) {
                    @SuppressWarnings("unchecked")
                    Class<? extends FunctionExpression> c = (Class<? extends FunctionExpression>)ec;
                    addBuiltInFunction(f = c);
                }
            }
            initialized = true;
        } catch (Exception e) {
            throw new RuntimeException("Failed initialization of built-in functions at class '" + f + "'", e);
        } finally {
            if (!initialized) {
                // Do not leave a partial registry behind; it would pass the isEmpty() check above
                BUILT_IN_FUNCTION_MAP.clear();
            }
        }
    }

    /**
     * Looks up a built-in function after normalizing {@code name} with
     * {@link SchemaUtil#normalizeIdentifier(String)}.
     */
    private static BuiltInFunctionInfo getInfo(String name, List<ParseNode> children) {
        return get(SchemaUtil.normalizeIdentifier(name), children);
    }

    /**
     * Looks up the built-in function registered under an already normalized name for the
     * number of supplied arguments, initializing the registry on first use.
     *
     * @param normalizedName function name as produced by {@link SchemaUtil#normalizeIdentifier(String)}
     * @param children the call's arguments; only their count is used for the lookup
     * @return the registered function info
     * @throws UnknownFunctionException if no function matches the name and argument count
     */
    public static BuiltInFunctionInfo get(String normalizedName, List<ParseNode> children) {
        initBuiltInFunctionMap();
        BuiltInFunctionInfo info = BUILT_IN_FUNCTION_MAP.get(new BuiltInFunctionKey(normalizedName, children.size()));
        if (info == null) {
            throw new UnknownFunctionException(normalizedName);
        }
        return info;
    }

    public ParseNodeFactory() {
    }

    /** Wraps a statement in an EXPLAIN statement. */
    public ExplainStatement explain(BindableStatement statement) {
        return new ExplainStatement(statement);
    }

    /** Creates a select-list entry pairing an expression with an optional alias. */
    public AliasedNode aliasedNode(String alias, ParseNode expression) {
        return new AliasedNode(alias, expression);
    }

    /** Creates an addition node over the given operands. */
    public AddParseNode add(List<ParseNode> children) {
        return new AddParseNode(children);
    }

    /** Creates a subtraction node over the given operands. */
    public SubtractParseNode subtract(List<ParseNode> children) {
        return new SubtractParseNode(children);
    }

    /** Creates a multiplication node over the given operands. */
    public MultiplyParseNode multiply(List<ParseNode> children) {
        return new MultiplyParseNode(children);
    }

    /** Creates a logical AND node over the given operands. */
    public AndParseNode and(List<ParseNode> children) {
        return new AndParseNode(children);
    }

    /** Creates a wildcard node selecting all columns of the given column family. */
    public FamilyWildcardParseNode family(String familyName) {
        return new FamilyWildcardParseNode(familyName, false);
    }

    /** Returns the shared {@link WildcardParseNode#INSTANCE}. */
    public WildcardParseNode wildcard() {
        return WildcardParseNode.INSTANCE;
    }

    /** Creates a {@code l [NOT] BETWEEN r1 AND r2} node. */
    public BetweenParseNode between(ParseNode l, ParseNode r1, ParseNode r2, boolean negate) {
        return new BetweenParseNode(l, r1, r2, negate);
    }

    /** Creates a bind-parameter node. */
    public BindParseNode bind(String bind) {
        return new BindParseNode(bind);
    }

    /** Creates a string concatenation node over the given operands. */
    public StringConcatParseNode concat(List<ParseNode> children) {
        return new StringConcatParseNode(children);
    }

    /** Creates a column reference, optionally qualified by table and aliased. */
    public ColumnParseNode column(TableName tableName, String name, String alias) {
        return new ColumnParseNode(tableName, name, alias);
    }

    /** Creates an unqualified column name. */
    public ColumnName columnName(String columnName) {
        return new ColumnName(columnName);
    }

    /** Creates a column name qualified by its column family. */
    public ColumnName columnName(String familyName, String columnName) {
        return new ColumnName(familyName, columnName);
    }

    /** Creates an unqualified property name. */
    public PropertyName propertyName(String propertyName) {
        return new PropertyName(propertyName);
    }

    /** Creates a property name qualified by its column family. */
    public PropertyName propertyName(String familyName, String propertyName) {
        return new PropertyName(familyName, propertyName);
    }

    /** Creates a (non-array) column definition. */
    public ColumnDef columnDef(ColumnName columnDefName, String sqlTypeName, boolean isNull, Integer maxLength, Integer scale, boolean isPK, ColumnModifier columnModifier) {
        return new ColumnDef(columnDefName, sqlTypeName, isNull, maxLength, scale, isPK, columnModifier);
    }

    /** Creates a column definition that may be an array with an optional declared size. */
    public ColumnDef columnDef(ColumnName columnDefName, String sqlTypeName, boolean isArray, Integer arrSize, boolean isNull, Integer maxLength, Integer scale, boolean isPK, ColumnModifier columnModifier) {
        return new ColumnDef(columnDefName, sqlTypeName, isArray, arrSize, isNull, maxLength, scale, isPK, columnModifier);
    }

    /** Creates a named primary key constraint over the given columns and their modifiers. */
    public PrimaryKeyConstraint primaryKey(String name, List<Pair<ColumnName, ColumnModifier>> columnNameAndModifier) {
        return new PrimaryKeyConstraint(name, columnNameAndModifier);
    }

    /** Creates a CREATE TABLE (or other {@link PTableType}) statement. */
    public CreateTableStatement createTable(TableName tableName, ListMultimap<String,Pair<String,Object>> props, List<ColumnDef> columns, PrimaryKeyConstraint pkConstraint, List<ParseNode> splits, PTableType tableType, boolean ifNotExists, TableName baseTableName, ParseNode tableTypeIdNode, int bindCount) {
        return new CreateTableStatement(tableName, props, columns, pkConstraint, splits, tableType, ifNotExists, baseTableName, tableTypeIdNode, bindCount);
    }

    /** Creates a CREATE INDEX statement. */
    public CreateIndexStatement createIndex(NamedNode indexName, NamedTableNode dataTable, PrimaryKeyConstraint pkConstraint, List<ColumnName> includeColumns, List<ParseNode> splits, ListMultimap<String,Pair<String,Object>> props, boolean ifNotExists, int bindCount) {
        return new CreateIndexStatement(indexName, dataTable, pkConstraint, includeColumns, splits, props, ifNotExists, bindCount);
    }

    /** Creates a CREATE SEQUENCE statement. */
    public CreateSequenceStatement createSequence(TableName tableName, ParseNode startsWith, ParseNode incrementBy, ParseNode cacheSize, boolean ifNotExists, int bindCount) {
        return new CreateSequenceStatement(tableName, startsWith, incrementBy, cacheSize, ifNotExists, bindCount);
    }

    /** Creates a DROP SEQUENCE statement. */
    public DropSequenceStatement dropSequence(TableName tableName, boolean ifExists, int bindCount) {
        return new DropSequenceStatement(tableName, ifExists, bindCount);
    }

    /** Creates a node reading the current value of a sequence. */
    public SequenceValueParseNode currentValueFor(TableName tableName) {
        return new SequenceValueParseNode(tableName, SequenceValueParseNode.Op.CURRENT_VALUE);
    }

    /** Creates a node reading the next value of a sequence. */
    public SequenceValueParseNode nextValueFor(TableName tableName) {
        return new SequenceValueParseNode(tableName, SequenceValueParseNode.Op.NEXT_VALUE);
    }

    /** Creates an ALTER ... ADD column statement. */
    public AddColumnStatement addColumn(NamedTableNode table, PTableType tableType, List<ColumnDef> columnDefs, boolean ifNotExists, Map<String,Object> props) {
        return new AddColumnStatement(table, tableType, columnDefs, ifNotExists, props);
    }

    /** Creates an ALTER ... DROP column statement. */
    public DropColumnStatement dropColumn(NamedTableNode table, PTableType tableType, List<ColumnName> columnNodes, boolean ifExists) {
        return new DropColumnStatement(table, tableType, columnNodes, ifExists);
    }

    /** Creates a DROP TABLE (or other {@link PTableType}) statement. */
    public DropTableStatement dropTable(TableName tableName, PTableType tableType, boolean ifExists) {
        return new DropTableStatement(tableName, tableType, ifExists);
    }

    /** Creates a DROP INDEX statement. */
    public DropIndexStatement dropIndex(NamedNode indexName, TableName tableName, boolean ifExists) {
        return new DropIndexStatement(indexName, tableName, ifExists);
    }

    /** Creates an ALTER INDEX statement that moves the index to the given state. */
    public AlterIndexStatement alterIndex(NamedTableNode indexTableNode, String dataTableName, boolean ifExists, PIndexState state) {
        return new AlterIndexStatement(indexTableNode, dataTableName, ifExists, state);
    }

    /** Creates a table name via {@link TableName#createNormalized(String, String)}. */
    public TableName table(String schemaName, String tableName) {
        return TableName.createNormalized(schemaName, tableName);
    }

    /** Creates a named node for an index. */
    public NamedNode indexName(String name) {
        return new NamedNode(name);
    }

    /** Creates a reference to a named table with an optional alias. */
    public NamedTableNode namedTable(String alias, TableName name) {
        return new NamedTableNode(alias, name);
    }

    /** Creates a reference to a named table that also declares dynamic columns. */
    public NamedTableNode namedTable(String alias, TableName name, List<ColumnDef> dynColumns) {
        return new NamedTableNode(alias, name, dynColumns);
    }

    /** Creates a table reference supplied through a bind parameter. */
    public BindTableNode bindTable(String alias, TableName name) {
        return new BindTableNode(alias, name);
    }

    /** Creates a CASE expression node. */
    public CaseParseNode caseWhen(List<ParseNode> children) {
        return new CaseParseNode(children);
    }

    /** Creates a division node over the given operands. */
    public DivideParseNode divide(List<ParseNode> children) {
        return new DivideParseNode(children);
    }

    /**
     * Creates a {@code name(DISTINCT args)} call. Only COUNT is supported; it is resolved to
     * the distinct-count aggregate.
     *
     * @throws UnsupportedOperationException for any function other than COUNT
     */
    public FunctionParseNode functionDistinct(String name, List<ParseNode> args) {
        if (CountAggregateFunction.NAME.equals(SchemaUtil.normalizeIdentifier(name))) {
            BuiltInFunctionInfo info = getInfo(
                    SchemaUtil.normalizeIdentifier(DistinctCountAggregateFunction.NAME), args);
            return new DistinctCountParseNode(name, args, info);
        } else {
            throw new UnsupportedOperationException("DISTINCT not supported with " + name);
        }
    }

    /** Creates an array element dereference as a call to the built-in {@value #ARRAY_ELEM} function. */
    public FunctionParseNode arrayElemRef(List<ParseNode> args) {
        return function(ARRAY_ELEM, args);
    }

    /**
     * Creates a call to a built-in function. If the function declares a custom parse node
     * constructor it is invoked reflectively; otherwise an {@link AggregateFunctionParseNode}
     * or plain {@link FunctionParseNode} is created depending on {@code info.isAggregate()}.
     *
     * @throws UnknownFunctionException if no function matches the name and argument count
     * @throws RuntimeException wrapping any failure of the reflective constructor call
     */
    public FunctionParseNode function(String name, List<ParseNode> args) {
        BuiltInFunctionInfo info = getInfo(name, args);
        Constructor<? extends FunctionParseNode> ctor = info.getNodeCtor();
        if (ctor == null) {
            return info.isAggregate()
                    ? new AggregateFunctionParseNode(name, args, info)
                    : new FunctionParseNode(name, args, info);
        } else {
            try {
                return ctor.newInstance(name, args, info);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Creates a call to {@code name} with the arguments reordered to
     * (column, isAscending literal, value).
     *
     * @throws UnsupportedOperationException unless exactly one value and one column are given
     */
    public FunctionParseNode function(String name, List<ParseNode> valueNodes,
            List<ParseNode> columnNodes, boolean isAscending) {
        // Right now we support PERCENT functions on only one column
        if (valueNodes.size() != 1 || columnNodes.size() != 1) {
            throw new UnsupportedOperationException(name + " not supported on multiple columns");
        }
        List<ParseNode> children = new ArrayList<ParseNode>(3);
        children.add(columnNodes.get(0));
        children.add(new LiteralParseNode(Boolean.valueOf(isAscending)));
        children.add(valueNodes.get(0));
        return function(name, children);
    }

    /** Creates a hint node from the raw hint text. */
    public HintNode hint(String hint) {
        return new HintNode(hint);
    }

    /** Creates a {@code [NOT] IN (list)} node. */
    public InListParseNode inList(List<ParseNode> children, boolean negate) {
        return new InListParseNode(children, negate);
    }

    /** Creates a {@code [NOT] EXISTS} node. */
    public ExistsParseNode exists(ParseNode l, ParseNode r, boolean negate) {
        return new ExistsParseNode(l, r, negate);
    }

    /** Creates a {@code l [NOT] IN r} node. */
    public InParseNode in(ParseNode l, ParseNode r, boolean negate) {
        return new InParseNode(l, r, negate);
    }

    /** Creates an {@code IS [NOT] NULL} node. */
    public IsNullParseNode isNull(ParseNode child, boolean negate) {
        return new IsNullParseNode(child, negate);
    }

    /** Creates a join of the given type against {@code table} with the given ON condition. */
    public JoinTableNode join(JoinType type, ParseNode on, TableNode table) {
        return new JoinTableNode(type, on, table);
    }

    /** Creates a derived table (subquery in FROM) with the given alias. */
    public DerivedTableNode derivedTable(String alias, SelectStatement select) {
        return new DerivedTableNode(alias, select);
    }

    /** Creates a {@code lhs [NOT] LIKE rhs} node. */
    public LikeParseNode like(ParseNode lhs, ParseNode rhs, boolean negate) {
        return new LikeParseNode(lhs, rhs, negate);
    }

    /** Creates a literal node holding {@code value} as-is. */
    public LiteralParseNode literal(Object value) {
        return new LiteralParseNode(value);
    }

    /** Creates a CAST node to the type named by {@code dataType}. */
    public CastParseNode cast(ParseNode expression, String dataType) {
        return new CastParseNode(expression, dataType);
    }

    /** Creates a CAST node to the given data type. */
    public CastParseNode cast(ParseNode expression, PDataType dataType) {
        return new CastParseNode(expression, dataType);
    }

    /** Creates a row value constructor node, e.g. {@code (a, b, c)}. */
    public ParseNode rowValueConstructor(List<ParseNode> l) {
        return new RowValueConstructorParseNode(l);
    }

    /**
     * Throws a type mismatch unless {@code expectedType.isCoercibleTo(actualType)}.
     * NOTE(review): callers then convert the value from {@code actualType} to
     * {@code expectedType}, i.e. the opposite direction of this check. Confirm that this
     * direction is intended.
     */
    private void checkTypeMatch(PDataType expectedType, PDataType actualType) throws SQLException {
        if (!expectedType.isCoercibleTo(actualType)) {
            throw TypeMismatchException.newException(expectedType, actualType);
        }
    }

    /**
     * Creates a literal whose value is converted to {@code expectedType} when the type
     * inferred by {@link PDataType#fromLiteral(Object)} differs from it.
     *
     * @throws SQLException if the inferred and expected types do not match
     */
    public LiteralParseNode literal(Object value, PDataType expectedType) throws SQLException {
        PDataType actualType = PDataType.fromLiteral(value);
        if (actualType != null && actualType != expectedType) {
            checkTypeMatch(expectedType, actualType);
            value = expectedType.toObject(value, actualType);
        }
        return new LiteralParseNode(value);
    }

    /**
     * Converts a literal's value to {@code expectedType}. Returns the original node when the
     * literal has no type or the conversion yields the same object. Otherwise returns a new
     * literal holding the converted value.
     *
     * @throws SQLException if the literal's type and the expected type do not match
     */
    public LiteralParseNode coerce(LiteralParseNode literalNode, PDataType expectedType) throws SQLException {
        PDataType actualType = literalNode.getType();
        if (actualType != null) {
            Object before = literalNode.getValue();
            checkTypeMatch(expectedType, actualType);
            Object after = expectedType.toObject(before, actualType);
            // Identity comparison on purpose: only allocate a new node if conversion produced a new object
            if (before != after) {
                literalNode = literal(after);
            }
        }
        return literalNode;
    }

    /**
     * Creates the comparison node corresponding to an HBase {@link CompareOp}.
     *
     * @throws IllegalArgumentException for operators without a SQL comparison counterpart
     */
    public ComparisonParseNode comparison(CompareOp op, ParseNode lhs, ParseNode rhs) {
        switch (op) {
        case LESS:
            return lt(lhs, rhs);
        case LESS_OR_EQUAL:
            return lte(lhs, rhs);
        case EQUAL:
            return equal(lhs, rhs);
        case NOT_EQUAL:
            return notEqual(lhs, rhs);
        case GREATER_OR_EQUAL:
            return gte(lhs, rhs);
        case GREATER:
            return gt(lhs, rhs);
        default:
            throw new IllegalArgumentException("Unexpected CompareOp of " + op);
        }
    }

    /** Creates a {@code lhs > rhs} node. */
    public GreaterThanParseNode gt(ParseNode lhs, ParseNode rhs) {
        return new GreaterThanParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs >= rhs} node. */
    public GreaterThanOrEqualParseNode gte(ParseNode lhs, ParseNode rhs) {
        return new GreaterThanOrEqualParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs < rhs} node. */
    public LessThanParseNode lt(ParseNode lhs, ParseNode rhs) {
        return new LessThanParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs <= rhs} node. */
    public LessThanOrEqualParseNode lte(ParseNode lhs, ParseNode rhs) {
        return new LessThanOrEqualParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs = rhs} node. */
    public EqualParseNode equal(ParseNode lhs, ParseNode rhs) {
        return new EqualParseNode(lhs, rhs);
    }

    /** Creates an array constructor node from the given element expressions. */
    public ArrayConstructorNode upsertStmtArrayNode(List<ParseNode> upsertStmtArray) {
        return new ArrayConstructorNode(upsertStmtArray);
    }

    /** Creates a unary negation, expressed as {@code child * -1}. */
    public MultiplyParseNode negate(ParseNode child) {
        return new MultiplyParseNode(Arrays.asList(child, this.literal(-1)));
    }

    /** Creates a {@code lhs != rhs} node. */
    public NotEqualParseNode notEqual(ParseNode lhs, ParseNode rhs) {
        return new NotEqualParseNode(lhs, rhs);
    }

    /** Creates a logical NOT node. */
    public NotParseNode not(ParseNode child) {
        return new NotParseNode(child);
    }

    /** Creates a logical OR node over the given operands. */
    public OrParseNode or(List<ParseNode> children) {
        return new OrParseNode(children);
    }

    /** Creates an ORDER BY term with its null ordering and direction. */
    public OrderByNode orderBy(ParseNode expression, boolean nullsLast, boolean orderAscending) {
        return new OrderByNode(expression, nullsLast, orderAscending);
    }

    /** Wraps a node as an outer-join marker. */
    public OuterJoinParseNode outer(ParseNode node) {
        return new OuterJoinParseNode(node);
    }

    /**
     * Creates a SELECT statement. A {@code null} {@code groupBy} or {@code orderBy} is
     * replaced by an empty list.
     */
    public SelectStatement select(List<? extends TableNode> from, HintNode hint, boolean isDistinct, List<AliasedNode> select, ParseNode where,
            List<ParseNode> groupBy, ParseNode having, List<OrderByNode> orderBy, LimitNode limit, int bindCount, boolean isAggregate) {

        return new SelectStatement(from, hint, isDistinct, select, where,
                groupBy == null ? Collections.<ParseNode>emptyList() : groupBy, having,
                orderBy == null ? Collections.<OrderByNode>emptyList() : orderBy, limit, bindCount, isAggregate);
    }

    /** Creates an UPSERT statement from either explicit values or a SELECT. */
    public UpsertStatement upsert(NamedTableNode table, HintNode hint, List<ColumnName> columns, List<ParseNode> values, SelectStatement select, int bindCount) {
        return new UpsertStatement(table, hint, columns, values, select, bindCount);
    }

    /** Creates a DELETE statement. */
    public DeleteStatement delete(NamedTableNode table, HintNode hint, ParseNode node, List<OrderByNode> orderBy, LimitNode limit, int bindCount) {
        return new DeleteStatement(table, hint, node, orderBy, limit, bindCount);
    }

    /** Returns a copy of {@code statement} with its WHERE and HAVING clauses replaced. */
    public SelectStatement select(SelectStatement statement, ParseNode where, ParseNode having) {
        return select(statement.getFrom(), statement.getHint(), statement.isDistinct(), statement.getSelect(), where, statement.getGroupBy(), having,
                statement.getOrderBy(), statement.getLimit(), statement.getBindCount(), statement.isAggregate());
    }

    /** Returns a copy of {@code statement} with its FROM clause replaced by {@code tables}. */
    public SelectStatement select(SelectStatement statement, List<? extends TableNode> tables) {
        return select(tables, statement.getHint(), statement.isDistinct(), statement.getSelect(), statement.getWhere(), statement.getGroupBy(),
                statement.getHaving(), statement.getOrderBy(), statement.getLimit(), statement.getBindCount(), statement.isAggregate());
    }

    /**
     * Returns a copy of {@code statement} using {@code hint}, or {@code statement} itself when
     * the hint is {@code null} or empty.
     */
    public SelectStatement select(SelectStatement statement, HintNode hint) {
        return hint == null || hint.isEmpty() ? statement : select(statement.getFrom(), hint, statement.isDistinct(), statement.getSelect(),
                statement.getWhere(), statement.getGroupBy(), statement.getHaving(), statement.getOrderBy(), statement.getLimit(),
                statement.getBindCount(), statement.isAggregate());
    }

    /** Wraps a SELECT statement as a subquery expression. */
    public SubqueryParseNode subquery(SelectStatement select) {
        return new SubqueryParseNode(select);
    }

    /** Creates a LIMIT whose value comes from a bind parameter. */
    public LimitNode limit(BindParseNode b) {
        return new LimitNode(b);
    }

    /** Creates a LIMIT whose value is a literal. */
    public LimitNode limit(LiteralParseNode l) {
        return new LimitNode(l);
    }
}