/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to you under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.calcite.test; import org.apache.calcite.adapter.enumerable.EnumerableConvention; import org.apache.calcite.adapter.enumerable.EnumerableHashJoin; import org.apache.calcite.adapter.enumerable.EnumerableRules; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.ConventionTraitDef; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.plan.volcano.AbstractConverter; import org.apache.calcite.plan.volcano.VolcanoPlanner; import org.apache.calcite.rel.RelCollationTraitDef; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelShuttleImpl; import org.apache.calcite.rel.RelVisitor; import org.apache.calcite.rel.convert.ConverterRule; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.Calc; import org.apache.calcite.rel.core.Join; import 
org.apache.calcite.rel.core.JoinInfo; import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.hint.HintPredicate; import org.apache.calcite.rel.hint.HintPredicates; import org.apache.calcite.rel.hint.HintStrategy; import org.apache.calcite.rel.hint.HintStrategyTable; import org.apache.calcite.rel.hint.Hintable; import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; import org.apache.calcite.rel.rules.FilterMergeRule; import org.apache.calcite.rel.rules.FilterProjectTransposeRule; import org.apache.calcite.rel.rules.ProjectMergeRule; import org.apache.calcite.rel.rules.ProjectToCalcRule; import org.apache.calcite.sql.SqlDelete; import org.apache.calcite.sql.SqlInsert; import org.apache.calcite.sql.SqlMerge; import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlTableRef; import org.apache.calcite.sql.SqlUpdate; import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.tools.Program; import org.apache.calcite.tools.Programs; import org.apache.calcite.tools.RuleSet; import org.apache.calcite.tools.RuleSets; import org.apache.calcite.util.Litmus; import org.apache.calcite.util.Util; import org.apache.log4j.AppenderSkeleton; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.log4j.spi.LoggingEvent; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.function.UnaryOperator; import java.util.stream.Collectors; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.collection.IsIn.in; import static org.hamcrest.core.Is.is; import static 
org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Unit test for {@link org.apache.calcite.rel.hint.RelHint}.
 *
 * <p>Covers conversion of SQL query/table hints into {@link RelHint}s and the
 * propagation of those hints through Hep and Volcano planner rules.
 */
class SqlHintsConverterTest extends SqlToRelTestBase {

  protected DiffRepository getDiffRepos() {
    return DiffRepository.lookup(SqlHintsConverterTest.class);
  }

  //~ Tests ------------------------------------------------------------------

  @Test void testQueryHint() {
    final String sql = HintTools.withHint("select /*+ %s */ *\n"
        + "from emp e1\n"
        + "inner join dept d1 on e1.deptno = d1.deptno\n"
        + "inner join emp e2 on e1.ename = e2.job");
    sql(sql).ok();
  }

  @Test void testQueryHintWithLiteralOptions() {
    final String sql = "select /*+ time_zone(1, 1.23, 'a bc', -1.0) */ *\n"
        + "from emp";
    sql(sql).ok();
  }

  @Test void testNestedQueryHint() {
    final String sql = "select /*+ resource(parallelism='3'), repartition(10) */ empno\n"
        + "from (select /*+ resource(mem='20Mb')*/ empno, ename from emp)";
    sql(sql).ok();
  }

  @Test void testTwoLevelNestedQueryHint() {
    final String sql = "select /*+ resource(parallelism='3'), no_hash_join */ empno\n"
        + "from (select /*+ resource(mem='20Mb')*/ empno, ename\n"
        + "from emp left join dept on emp.deptno = dept.deptno)";
    sql(sql).ok();
  }

  @Test void testThreeLevelNestedQueryHint() {
    final String sql = "select /*+ index(idx1), no_hash_join */ * from emp /*+ index(empno) */\n"
        + "e1 join dept/*+ index(deptno) */ d1 on e1.deptno = d1.deptno\n"
        + "join emp e2 on d1.name = e2.job";
    sql(sql).ok();
  }

  @Test void testFourLevelNestedQueryHint() {
    final String sql = "select /*+ index(idx1), no_hash_join */ * from emp /*+ index(empno) */\n"
        + "e1 join dept/*+ index(deptno) */ d1 on e1.deptno = d1.deptno join\n"
        + "(select max(sal) as sal from emp /*+ index(empno) */) e2 on e1.sal = e2.sal";
    sql(sql).ok();
  }

  @Test void testAggregateHints() {
    final String sql = "select /*+ AGG_STRATEGY(TWO_PHASE), RESOURCE(mem='1024') */\n"
        + "count(deptno), avg_sal from (\n"
        + "select /*+ AGG_STRATEGY(ONE_PHASE) */ avg(sal) as avg_sal, deptno\n"
        + "from emp group by deptno) group by avg_sal";
    sql(sql).ok();
  }

  @Test void testHintsInSubQueryWithDecorrelation() {
    final String sql = "select /*+ resource(parallelism='3'), AGG_STRATEGY(TWO_PHASE) */\n"
        + "sum(e1.empno) from emp e1, dept d1\n"
        + "where e1.deptno = d1.deptno\n"
        + "and e1.sal> (\n"
        + "select /*+ resource(cpu='2') */ avg(e2.sal) from emp e2 where e2.deptno = d1.deptno)";
    sql(sql).withTester(t -> t.withDecorrelation(true)).ok();
  }

  @Test void testHintsInSubQueryWithDecorrelation2() {
    final String sql = "select /*+ properties(k1='v1', k2='v2'), index(ename), no_hash_join */\n"
        + "sum(e1.empno) from emp e1, dept d1\n"
        + "where e1.deptno = d1.deptno\n"
        + "and e1.sal> (\n"
        + "select /*+ properties(k1='v1', k2='v2'), index(ename), no_hash_join */\n"
        + " avg(e2.sal)\n"
        + " from emp e2\n"
        + " where e2.deptno = d1.deptno)";
    sql(sql).withTester(t -> t.withDecorrelation(true)).ok();
  }

  @Test void testHintsInSubQueryWithDecorrelation3() {
    final String sql = "select /*+ resource(parallelism='3'), index(ename), no_hash_join */\n"
        + "sum(e1.empno) from emp e1, dept d1\n"
        + "where e1.deptno = d1.deptno\n"
        + "and e1.sal> (\n"
        + "select /*+ resource(cpu='2'), index(ename), no_hash_join */\n"
        + " avg(e2.sal)\n"
        + " from emp e2\n"
        + " where e2.deptno = d1.deptno)";
    sql(sql).withTester(t -> t.withDecorrelation(true)).ok();
  }

  @Test void testHintsInSubQueryWithoutDecorrelation() {
    final String sql = "select /*+ resource(parallelism='3') */\n"
        + "sum(e1.empno) from emp e1, dept d1\n"
        + "where e1.deptno = d1.deptno\n"
        + "and e1.sal> (\n"
        + "select /*+ resource(cpu='2') */ avg(e2.sal) from emp e2 where e2.deptno = d1.deptno)";
    sql(sql).ok();
  }

  @Test void testInvalidQueryHint() {
    final String sql = "select /*+ weird_hint */ empno\n"
        + "from (select /*+ resource(mem='20Mb')*/ empno, ename\n"
        + "from emp left join dept on emp.deptno = dept.deptno)";
    sql(sql).warns("Hint: WEIRD_HINT should be registered in the HintStrategyTable");

    final String sql1 = "select /*+ resource(mem='20Mb')*/ empno\n"
        + "from (select /*+ weird_kv_hint(k1='v1') */ empno, ename\n"
        + "from emp left join dept on emp.deptno = dept.deptno)";
    sql(sql1).warns("Hint: WEIRD_KV_HINT should be registered in the HintStrategyTable");

    final String sql2 = "select /*+ AGG_STRATEGY(OPTION1) */\n"
        + "ename, avg(sal)\n"
        + "from emp group by ename";
    final String error2 = "Hint AGG_STRATEGY only allows single option, "
        + "allowed options: [ONE_PHASE, TWO_PHASE]";
    sql(sql2).warns(error2);
    // Change the error handler to validate again.
    sql(sql2).withTester(
        tester -> tester.withConfig(
            SqlToRelConverter.configBuilder()
                .withHintStrategyTable(
                    HintTools.createHintStrategies(
                        HintStrategyTable.builder().errorHandler(Litmus.THROW)))
                .build()))
        .fails(error2);
  }

  @Test void testTableHintsInJoin() {
    final String sql = "select\n"
        + "ename, job, sal, dept.name\n"
        + "from emp /*+ index(idx1, idx2) */\n"
        + "join dept /*+ properties(k1='v1', k2='v2') */\n"
        + "on emp.deptno = dept.deptno";
    sql(sql).ok();
  }

  @Test void testTableHintsInSelect() {
    final String sql = HintTools.withHint("select * from emp /*+ %s */");
    sql(sql).ok();
  }

  @Test void testSameHintsWithDifferentInheritPath() {
    final String sql = "select /*+ properties(k1='v1', k2='v2') */\n"
        + "ename, job, sal, dept.name\n"
        + "from emp /*+ index(idx1, idx2) */\n"
        + "join dept /*+ properties(k1='v1', k2='v2') */\n"
        + "on emp.deptno = dept.deptno";
    sql(sql).ok();
  }

  @Test void testTableHintsInInsert() throws Exception {
    final String sql = HintTools.withHint("insert into dept /*+ %s */ (deptno, name) "
        + "select deptno, name from dept");
    final SqlInsert insert = (SqlInsert) tester.parseQuery(sql);
    assert insert.getTargetTable() instanceof SqlTableRef;
    final SqlTableRef tableRef = (SqlTableRef) insert.getTargetTable();
    List<RelHint> hints = SqlUtil.getRelHint(HintTools.HINT_STRATEGY_TABLE,
        (SqlNodeList) tableRef.getOperandList().get(1));
    assertHintsEquals(
        Arrays.asList(
            HintTools.PROPS_HINT,
            HintTools.IDX_HINT,
            HintTools.JOIN_HINT),
        hints);
  }

  @Test void testTableHintsInUpdate() throws Exception {
    final String sql = HintTools.withHint("update emp /*+ %s */ "
        + "set name = 'test' where deptno = 1");
    final SqlUpdate sqlUpdate = (SqlUpdate) tester.parseQuery(sql);
    assert sqlUpdate.getTargetTable() instanceof SqlTableRef;
    final SqlTableRef tableRef = (SqlTableRef) sqlUpdate.getTargetTable();
    List<RelHint> hints = SqlUtil.getRelHint(HintTools.HINT_STRATEGY_TABLE,
        (SqlNodeList) tableRef.getOperandList().get(1));
    assertHintsEquals(
        Arrays.asList(
            HintTools.PROPS_HINT,
            HintTools.IDX_HINT,
            HintTools.JOIN_HINT),
        hints);
  }

  @Test void testTableHintsInDelete() throws Exception {
    final String sql = HintTools.withHint("delete from emp /*+ %s */ where deptno = 1");
    final SqlDelete sqlDelete = (SqlDelete) tester.parseQuery(sql);
    assert sqlDelete.getTargetTable() instanceof SqlTableRef;
    final SqlTableRef tableRef = (SqlTableRef) sqlDelete.getTargetTable();
    List<RelHint> hints = SqlUtil.getRelHint(HintTools.HINT_STRATEGY_TABLE,
        (SqlNodeList) tableRef.getOperandList().get(1));
    assertHintsEquals(
        Arrays.asList(
            HintTools.PROPS_HINT,
            HintTools.IDX_HINT,
            HintTools.JOIN_HINT),
        hints);
  }

  @Test void testTableHintsInMerge() throws Exception {
    final String sql = "merge into emps\n"
        + "/*+ %s */ e\n"
        + "using tempemps as t\n"
        + "on e.empno = t.empno\n"
        + "when matched then update\n"
        + "set name = t.name, deptno = t.deptno, salary = t.salary * .1\n"
        + "when not matched then insert (name, dept, salary)\n"
        + "values(t.name, 10, t.salary * .15)";
    final String sql1 = HintTools.withHint(sql);

    final SqlMerge sqlMerge = (SqlMerge) tester.parseQuery(sql1);
    assert sqlMerge.getTargetTable() instanceof SqlTableRef;
    final SqlTableRef tableRef = (SqlTableRef) sqlMerge.getTargetTable();
    List<RelHint> hints = SqlUtil.getRelHint(HintTools.HINT_STRATEGY_TABLE,
        (SqlNodeList) tableRef.getOperandList().get(1));
    assertHintsEquals(
        Arrays.asList(
            HintTools.PROPS_HINT,
            HintTools.IDX_HINT,
            HintTools.JOIN_HINT),
        hints);
  }

  @Test void testInvalidTableHints() {
    final String sql = "select\n"
        + "ename, job, sal, dept.name\n"
        + "from emp /*+ weird_hint(idx1, idx2) */\n"
        + "join dept /*+ properties(k1='v1', k2='v2') */\n"
        + "on emp.deptno = dept.deptno";
    sql(sql).warns("Hint: WEIRD_HINT should be registered in the HintStrategyTable");

    final String sql1 = "select\n"
        + "ename, job, sal, dept.name\n"
        + "from emp /*+ index(idx1, idx2) */\n"
        + "join dept /*+ weird_kv_hint(k1='v1', k2='v2') */\n"
        + "on emp.deptno = dept.deptno";
    sql(sql1).warns("Hint: WEIRD_KV_HINT should be registered in the HintStrategyTable");
  }

  @Test void testJoinHintRequiresSpecificInputs() {
    final String sql = "select /*+ use_hash_join(r, s), use_hash_join(emp, dept) */\n"
        + "ename, job, sal, dept.name\n"
        + "from emp join dept on emp.deptno = dept.deptno";
    // Hint use_hash_join(r, s) expect to be ignored by the join node.
    sql(sql).ok();
  }

  @Test void testHintsForCalc() {
    final String sql = "select /*+ resource(mem='1024MB')*/ ename, sal, deptno from emp";
    final RelNode rel = tester.convertSqlToRel(sql).rel;
    final RelHint hint = RelHint.builder("RESOURCE")
        .hintOption("MEM", "1024MB")
        .build();
    // planner rule to convert Project to Calc.
    HepProgram program = new HepProgramBuilder()
        .addRuleInstance(ProjectToCalcRule.INSTANCE)
        .build();
    HepPlanner planner = new HepPlanner(program);
    planner.setRoot(rel);
    RelNode newRel = planner.findBestExp();
    new ValidateHintVisitor(hint, Calc.class).go(newRel);
  }

  @Test void testHintsPropagationInHepPlannerRules() {
    final String sql = "select /*+ use_hash_join(r, s), use_hash_join(emp, dept) */\n"
        + "ename, job, sal, dept.name\n"
        + "from emp join dept on emp.deptno = dept.deptno";
    final RelNode rel = tester.convertSqlToRel(sql).rel;
    final RelHint hint = RelHint.builder("USE_HASH_JOIN")
        .inheritPath(0)
        .hintOption("EMP")
        .hintOption("DEPT")
        .build();
    // Validate Hep planner.
    HepProgram program = new HepProgramBuilder()
        .addRuleInstance(MockJoinRule.INSTANCE)
        .build();
    HepPlanner planner = new HepPlanner(program);
    planner.setRoot(rel);
    RelNode newRel = planner.findBestExp();
    new ValidateHintVisitor(hint, Join.class).go(newRel);
  }

  @Test void testHintsPropagationInVolcanoPlannerRules() {
    final String sql = "select /*+ use_hash_join(r, s), use_hash_join(emp, dept) */\n"
        + "ename, job, sal, dept.name\n"
        + "from emp join dept on emp.deptno = dept.deptno";
    RelOptPlanner planner = new VolcanoPlanner();
    planner.addRelTraitDef(ConventionTraitDef.INSTANCE);
    Tester tester1 = tester.withDecorrelation(true)
        .withClusterFactory(
            relOptCluster -> RelOptCluster.create(planner, relOptCluster.getRexBuilder()));
    final RelNode rel = tester1.convertSqlToRel(sql).rel;
    final RelHint hint = RelHint.builder("USE_HASH_JOIN")
        .inheritPath(0)
        .hintOption("EMP")
        .hintOption("DEPT")
        .build();
    // Validate Volcano planner.
    RuleSet ruleSet = RuleSets.ofList(
        new MockEnumerableJoinRule(hint), // Rule to validate the hint.
        FilterProjectTransposeRule.INSTANCE,
        FilterMergeRule.INSTANCE,
        ProjectMergeRule.INSTANCE,
        EnumerableRules.ENUMERABLE_JOIN_RULE,
        EnumerableRules.ENUMERABLE_PROJECT_RULE,
        EnumerableRules.ENUMERABLE_FILTER_RULE,
        EnumerableRules.ENUMERABLE_SORT_RULE,
        EnumerableRules.ENUMERABLE_LIMIT_RULE,
        EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
    Program program = Programs.of(ruleSet);
    RelTraitSet toTraits = rel
        .getCluster()
        .traitSet()
        .replace(EnumerableConvention.INSTANCE);

    program.run(planner, rel, toTraits,
        Collections.emptyList(), Collections.emptyList());
  }

  @Test void testHintsPropagateWithDifferentKindOfRels() {
    final String sql = "select /*+ AGG_STRATEGY(TWO_PHASE) */\n"
        + "ename, avg(sal)\n"
        + "from emp group by ename";
    final RelNode rel = tester.convertSqlToRel(sql).rel;
    final RelHint hint = RelHint.builder("AGG_STRATEGY")
        .inheritPath(0)
        .hintOption("TWO_PHASE")
        .build();
    // AggregateReduceFunctionsRule does the transformation:
    // AGG -> PROJECT + AGG
    HepProgram program = new HepProgramBuilder()
        .addRuleInstance(AggregateReduceFunctionsRule.INSTANCE)
        .build();
    HepPlanner planner = new HepPlanner(program);
    planner.setRoot(rel);
    RelNode newRel = planner.findBestExp();
    new ValidateHintVisitor(hint, Aggregate.class).go(newRel);
  }

  @Test void testUseMergeJoin() {
    final String sql = "select /*+ use_merge_join(emp, dept) */\n"
        + "ename, job, sal, dept.name\n"
        + "from emp join dept on emp.deptno = dept.deptno";
    RelOptPlanner planner = new VolcanoPlanner();
    planner.addRelTraitDef(ConventionTraitDef.INSTANCE);
    planner.addRelTraitDef(RelCollationTraitDef.INSTANCE);
    Tester tester1 = tester.withDecorrelation(true)
        .withClusterFactory(
            relOptCluster -> RelOptCluster.create(planner, relOptCluster.getRexBuilder()));
    final RelNode rel = tester1.convertSqlToRel(sql).rel;
    RuleSet ruleSet = RuleSets.ofList(
        EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE,
        EnumerableRules.ENUMERABLE_JOIN_RULE,
        EnumerableRules.ENUMERABLE_PROJECT_RULE,
        EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE,
        EnumerableRules.ENUMERABLE_SORT_RULE,
        AbstractConverter.ExpandConversionRule.INSTANCE);
    Program program = Programs.of(ruleSet);
    RelTraitSet toTraits = rel
        .getCluster()
        .traitSet()
        .replace(EnumerableConvention.INSTANCE);

    RelNode relAfter = program.run(planner, rel, toTraits,
        Collections.emptyList(), Collections.emptyList());

    String planAfter = NL + RelOptUtil.toString(relAfter);
    getDiffRepos().assertEquals("planAfter", "${planAfter}", planAfter);
  }

  //~ Methods ----------------------------------------------------------------

  @Override protected Tester createTester() {
    return super.createTester()
        .withConfig(SqlToRelConverter
            .configBuilder()
            .withHintStrategyTable(HintTools.HINT_STRATEGY_TABLE)
            .build());
  }

  /** Sets the SQL statement for a test. */
  public final Sql sql(String sql) {
    return new Sql(sql, tester);
  }

  /**
   * Returns whether two string lists contain the same elements with the same
   * number of occurrences, ignoring order.
   *
   * @param l first list
   * @param r second list
   * @return true if {@code l} and {@code r} are equal as multisets
   */
  private static boolean equalsStringList(List<String> l, List<String> r) {
    if (l.size() != r.size()) {
      return false;
    }
    // Compare sorted copies so duplicates are counted; a plain "contains"
    // check would wrongly treat [A, A] and [A, B] as equal.
    final List<String> sortedL = new ArrayList<>(l);
    final List<String> sortedR = new ArrayList<>(r);
    Collections.sort(sortedL);
    Collections.sort(sortedR);
    return sortedL.equals(sortedR);
  }

  /** Asserts that two hint lists are element-wise equal, in order. */
  private static void assertHintsEquals(List<RelHint> expected, List<RelHint> actual) {
    assertArrayEquals(expected.toArray(new RelHint[0]), actual.toArray(new RelHint[0]));
  }

  //~ Inner Class ------------------------------------------------------------

  /** A Mock rule to validate the hint. */
  private static class MockJoinRule extends RelOptRule {
    public static final MockJoinRule INSTANCE = new MockJoinRule();

    MockJoinRule() {
      super(operand(LogicalJoin.class, any()), "MockJoinRule");
    }

    /** Asserts the matched join carries exactly one hint, then re-creates it
     * with the same hints so the planner registers a transformation. */
    @Override public void onMatch(RelOptRuleCall call) {
      LogicalJoin join = call.rel(0);
      assertThat(join.getHints().size(), is(1));
      call.transformTo(
          LogicalJoin.create(join.getLeft(),
              join.getRight(),
              join.getHints(),
              join.getCondition(),
              join.getVariablesSet(),
              join.getJoinType()));
    }
  }

  /** A Mock rule to validate the hint.
   * This rule also converts the rel to EnumerableConvention. */
  private static class MockEnumerableJoinRule extends ConverterRule {
    private final RelHint expectedHint;

    MockEnumerableJoinRule(RelHint hint) {
      super(
          LogicalJoin.class,
          Convention.NONE,
          EnumerableConvention.INSTANCE,
          "MockEnumerableJoinRule");
      this.expectedHint = hint;
    }

    /** Asserts the join carries exactly the expected hint and converts it to
     * an {@link EnumerableHashJoin}, converting inputs to enumerable first. */
    @Override public RelNode convert(RelNode rel) {
      LogicalJoin join = (LogicalJoin) rel;
      assertThat(join.getHints().size(), is(1));
      assertThat(join.getHints().get(0), is(expectedHint));
      List<RelNode> newInputs = new ArrayList<>();
      for (RelNode input : join.getInputs()) {
        if (!(input.getConvention() instanceof EnumerableConvention)) {
          input =
              convert(
                  input,
                  input.getTraitSet()
                      .replace(EnumerableConvention.INSTANCE));
        }
        newInputs.add(input);
      }
      final RelOptCluster cluster = join.getCluster();
      final RelNode left = newInputs.get(0);
      final RelNode right = newInputs.get(1);
      final JoinInfo info = join.analyzeCondition();
      return EnumerableHashJoin.create(
          left,
          right,
          info.getEquiCondition(left, right, cluster.getRexBuilder()),
          join.getVariablesSet(),
          join.getJoinType());
    }
  }

  /** A visitor to validate a hintable node has specific hint. */
  private static class ValidateHintVisitor extends RelVisitor {
    private final RelHint expectedHint;
    private final Class<?> clazz;

    /**
     * Creates the validate visitor.
     *
     * @param hint  the hint to validate
     * @param clazz the node type to validate the hint with
     */
    ValidateHintVisitor(RelHint hint, Class<?> clazz) {
      this.expectedHint = hint;
      this.clazz = clazz;
    }

    @Override public void visit(
        RelNode node,
        int ordinal,
        RelNode parent) {
      if (clazz.isInstance(node)) {
        Hintable rel = (Hintable) node;
        assertThat(rel.getHints().size(), is(1));
        assertThat(rel.getHints().get(0), is(expectedHint));
      }
      super.visit(node, ordinal, parent);
    }
  }

  /** Sql test tool. */
  private static class Sql {
    private final String sql;
    private final Tester tester;

    Sql(String sql, Tester tester) {
      this.sql = sql;
      this.tester = tester;
    }

    /** Creates a new Sql instance with new tester
     * applied with the {@code transform}. */
    Sql withTester(UnaryOperator<Tester> transform) {
      return new Sql(this.sql, transform.apply(tester));
    }

    /** Converts the SQL and checks the collected hints against the
     * {@code hints} resource in the diff repository. */
    void ok() {
      assertHintsEquals(sql, "${hints}");
    }

    private void assertHintsEquals(
        String sql,
        String hint) {
      tester.getDiffRepos().assertEquals("sql", "${sql}", sql);
      String sql2 = tester.getDiffRepos().expand("sql", sql);
      final RelNode rel = tester.convertSqlToRel(sql2).project();

      assertNotNull(rel);
      assertValid(rel);

      // Collect into a fresh list per call so repeated checks on the same
      // Sql instance do not accumulate hint lines from earlier runs.
      final List<String> hintsCollect = new ArrayList<>();
      final HintCollector collector = new HintCollector(hintsCollect);
      rel.accept(collector);
      StringBuilder builder = new StringBuilder(NL);
      for (String hintLine : hintsCollect) {
        builder.append(hintLine).append(NL);
      }
      tester.getDiffRepos().assertEquals("hints", hint, builder.toString());
    }

    /**
     * Asserts that converting the SQL throws an {@link AssertionError} whose
     * message is exactly {@code failedMsg}.
     *
     * @param failedMsg expected error message
     */
    void fails(String failedMsg) {
      try {
        tester.convertSqlToRel(sql);
      } catch (AssertionError e) {
        assertThat(e.getMessage(), is(failedMsg));
        return;
      }
      // Must be outside the try: JUnit's fail() throws an AssertionError
      // subclass, which the catch above would otherwise swallow.
      fail("Expected conversion to fail with message: " + failedMsg);
    }

    /**
     * Converts the SQL and asserts that {@code expectWarning} was logged at
     * WARN level on the root logger during conversion.
     *
     * @param expectWarning expected rendered warning message
     */
    void warns(String expectWarning) {
      final MockAppender appender = new MockAppender();
      final Logger logger = Logger.getRootLogger();
      logger.addAppender(appender);
      try {
        tester.convertSqlToRel(sql);
      } finally {
        logger.removeAppender(appender);
      }
      List<String> warnings = appender.loggingEvents.stream()
          .filter(e -> e.getLevel() == Level.WARN)
          .map(LoggingEvent::getRenderedMessage)
          .collect(Collectors.toList());
      assertThat(expectWarning, is(in(warnings)));
    }

    /** A shuttle to collect all the hints within the relational expression
     * into a collection. */
    private static class HintCollector extends RelShuttleImpl {
      private final List<String> hintsCollect;

      HintCollector(List<String> hintsCollect) {
        this.hintsCollect = hintsCollect;
      }

      @Override public RelNode visit(TableScan scan) {
        if (!scan.getHints().isEmpty()) {
          this.hintsCollect.add("TableScan:" + scan.getHints().toString());
        }
        return super.visit(scan);
      }

      @Override public RelNode visit(LogicalJoin join) {
        if (!join.getHints().isEmpty()) {
          this.hintsCollect.add("LogicalJoin:" + join.getHints().toString());
        }
        return super.visit(join);
      }

      @Override public RelNode visit(LogicalProject project) {
        if (!project.getHints().isEmpty()) {
          this.hintsCollect.add("Project:" + project.getHints().toString());
        }
        return super.visit(project);
      }

      @Override public RelNode visit(LogicalAggregate aggregate) {
        if (!aggregate.getHints().isEmpty()) {
          this.hintsCollect.add("Aggregate:" + aggregate.getHints().toString());
        }
        return super.visit(aggregate);
      }
    }
  }

  /** Mock appender to collect the logging events. */
  private static class MockAppender extends AppenderSkeleton {
    public final List<LoggingEvent> loggingEvents = new ArrayList<>();

    @Override protected void append(LoggingEvent event) {
      loggingEvents.add(event);
    }

    @Override public void close() {
      // no-op
    }

    @Override public boolean requiresLayout() {
      return false;
    }
  }

  /** Define some tool members and methods for hints test. */
  private static class HintTools {
    //~ Static fields/initializers ---------------------------------------------

    /** Hint text substituted for {@code %s} by {@link #withHint(String)}. */
    static final String HINT = "properties(k1='v1', k2='v2'), index(ename), no_hash_join";

    static final RelHint PROPS_HINT = RelHint.builder("PROPERTIES")
        .hintOption("K1", "v1")
        .hintOption("K2", "v2")
        .build();

    static final RelHint IDX_HINT = RelHint.builder("INDEX")
        .hintOption("ENAME")
        .build();

    static final RelHint JOIN_HINT = RelHint.builder("NO_HASH_JOIN").build();

    static final HintStrategyTable HINT_STRATEGY_TABLE = createHintStrategies();

    //~ Methods ----------------------------------------------------------------

    /**
     * Creates mock hint strategies.
     *
     * @return HintStrategyTable instance
     */
    private static HintStrategyTable createHintStrategies() {
      return createHintStrategies(HintStrategyTable.builder());
    }

    /**
     * Creates mock hint strategies with given builder.
     *
     * @param builder builder to register the mock strategies on; its error
     *                handler decides how invalid hints are reported
     * @return HintStrategyTable instance
     */
    static HintStrategyTable createHintStrategies(HintStrategyTable.Builder builder) {
      return builder
          .hintStrategy("no_hash_join", HintPredicates.JOIN)
          .hintStrategy("time_zone", HintPredicates.SET_VAR)
          .hintStrategy("REPARTITION", HintPredicates.SET_VAR)
          .hintStrategy("index", HintPredicates.TABLE_SCAN)
          .hintStrategy("properties", HintPredicates.TABLE_SCAN)
          .hintStrategy(
              "resource", HintPredicates.or(
                  HintPredicates.PROJECT, HintPredicates.AGGREGATE, HintPredicates.CALC))
          .hintStrategy("AGG_STRATEGY",
              HintStrategy.builder(HintPredicates.AGGREGATE)
                  .optionChecker(
                      (hint, errorHandler) -> errorHandler.check(
                          hint.listOptions.size() == 1
                              && (hint.listOptions.get(0).equalsIgnoreCase("ONE_PHASE")
                              || hint.listOptions.get(0).equalsIgnoreCase("TWO_PHASE")),
                          "Hint {} only allows single option, "
                              + "allowed options: [ONE_PHASE, TWO_PHASE]",
                          hint.hintName)).build())
          .hintStrategy("use_hash_join",
              HintPredicates.and(HintPredicates.JOIN, joinWithFixedTableName()))
          .hintStrategy("use_merge_join",
              HintStrategy.builder(
                  HintPredicates.and(HintPredicates.JOIN, joinWithFixedTableName()))
                  .excludedRules(EnumerableRules.ENUMERABLE_JOIN_RULE).build())
          .build();
    }

    /** Returns a {@link HintPredicate} for join with specified table references:
     * the hint applies only to a {@link LogicalJoin} whose direct
     * {@link TableScan} inputs are exactly the tables named in the hint options. */
    private static HintPredicate joinWithFixedTableName() {
      return (hint, rel) -> {
        if (!(rel instanceof LogicalJoin)) {
          return false;
        }
        LogicalJoin join = (LogicalJoin) rel;
        final List<String> tableNames = hint.listOptions;
        final List<String> inputTables = join.getInputs().stream()
            .filter(input -> input instanceof TableScan)
            .map(scan -> Util.last(scan.getTable().getQualifiedName()))
            .collect(Collectors.toList());
        return equalsStringList(tableNames, inputTables);
      };
    }

    /** Formats the query with hint {@link #HINT}. */
    static String withHint(String sql) {
      return String.format(Locale.ROOT, sql, HINT);
    }
  }
}