Java Code Examples for org.apache.calcite.rel.core.Aggregate#getGroupCount()

The following examples show how to use org.apache.calcite.rel.core.Aggregate#getGroupCount(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: RelMdUtil.java    From Bats with Apache License 2.0 6 votes vote down vote up
/**
 * Takes a bitmap of an aggregate's output columns and computes the input
 * columns that those output columns are derived from.
 *
 * <p>The aggregate's output is laid out as group-by columns, then indicator
 * columns (if any), then one column per aggregate call. A group-by column
 * maps to the input column it groups on. An aggregate column maps to every
 * argument of its call.
 *
 * @param groupKey the original bitmap, in terms of the aggregate's output
 *                 columns
 * @param aggRel   the aggregate
 * @param childKey receives bits, in terms of the aggregate's input columns,
 *                 for the columns that {@code groupKey} depends on
 */
public static void setAggChildKeys(
    ImmutableBitSet groupKey,
    Aggregate aggRel,
    ImmutableBitSet.Builder childKey) {
  final List<AggregateCall> aggCalls = aggRel.getAggCallList();
  final ImmutableBitSet groupSet = aggRel.getGroupSet();
  final int groupCount = aggRel.getGroupCount();
  final int indicatorCount = aggRel.getIndicatorCount();
  for (int bit : groupKey) {
    if (bit < groupCount) {
      // group by column: output position "bit" holds the bit-th grouped
      // input column. That is input column "bit" only when the group set is
      // the prefix {0, 1, ..., groupCount - 1}.
      childKey.set(groupSet.nth(bit));
    } else if (bit < groupCount + indicatorCount) {
      // indicator column: must be handled before the aggregate branch,
      // otherwise the index into aggCalls below would be negative.
      // NOTE(review): assumes indicator i belongs to group column i.
      childKey.set(groupSet.nth(bit - groupCount));
    } else {
      // aggregate column -- set a bit for each argument being aggregated
      final AggregateCall agg =
          aggCalls.get(bit - groupCount - indicatorCount);
      for (int arg : agg.getArgList()) {
        childKey.set(arg);
      }
    }
  }
}
 
Example 2
Source File: AggregateReduceFunctionsRule.java    From calcite with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a single-argument call to {@code aggFunction} over input ordinal
 * {@code argOrdinal}. The call copies DISTINCT, approximate, ignore-nulls and
 * collation from {@code oldCall}. Its return type is inferred from a binding
 * with one operand of type {@code operandType}, under the old aggregate's
 * group count.
 *
 * @param filter filter argument ordinal; a negative value means no filter
 */
private AggregateCall createAggregateCallWithBinding(
    RelDataTypeFactory typeFactory,
    SqlAggFunction aggFunction,
    RelDataType operandType,
    Aggregate oldAggRel,
    AggregateCall oldCall,
    int argOrdinal,
    int filter) {
  final boolean hasFilter = filter >= 0;
  final int groupCount = oldAggRel.getGroupCount();
  final Aggregate.AggCallBinding callBinding =
      new Aggregate.AggCallBinding(typeFactory, aggFunction,
          ImmutableList.of(operandType), groupCount, hasFilter);
  final RelDataType returnType = aggFunction.inferReturnType(callBinding);
  return AggregateCall.create(aggFunction,
      oldCall.isDistinct(),
      oldCall.isApproximate(),
      oldCall.ignoreNulls(),
      ImmutableIntList.of(argOrdinal),
      filter,
      oldCall.collation,
      returnType,
      null);
}
 
Example 3
Source File: RelMdColumnOrigins.java    From calcite with Apache License 2.0 6 votes vote down vote up
/**
 * Returns the origins of one output column of an {@link Aggregate}.
 *
 * <p>A group-by column passes through from the input column it groups on.
 * An aggregate column is derived from the origins of all its call's
 * arguments.
 *
 * @param rel           the aggregate
 * @param mq            metadata query used to resolve the input's origins
 * @param iOutputColumn ordinal of the aggregate's output column
 * @return the origins of the column
 */
public Set<RelColumnOrigin> getColumnOrigins(Aggregate rel,
    RelMetadataQuery mq, int iOutputColumn) {
  final int groupCount = rel.getGroupCount();
  if (iOutputColumn < groupCount) {
    // Group columns pass through directly. However, output column i is the
    // i-th grouped input column, not input column i. For example,
    // GROUP BY input columns {2, 5} produces output columns 0 and 1.
    return mq.getColumnOrigins(rel.getInput(),
        rel.getGroupSet().nth(iOutputColumn));
  }

  // Aggregate columns are derived from input columns
  final AggregateCall call =
      rel.getAggCallList().get(iOutputColumn - groupCount);

  final Set<RelColumnOrigin> set = new HashSet<>();
  for (Integer iInput : call.getArgList()) {
    final Set<RelColumnOrigin> inputSet =
        createDerivedColumnOrigins(
            mq.getColumnOrigins(rel.getInput(), iInput));
    if (inputSet != null) {
      set.addAll(inputSet);
    }
  }
  return set;
}
 
Example 4
Source File: FlinkAggregateRemoveRule.java    From flink with Apache License 2.0 6 votes vote down vote up
/**
 * Matches an aggregate whose group keys are already unique in its input.
 * The aggregate must be a SIMPLE GROUP BY with at least one key and no
 * indicator columns. Every call must be a single-argument, unfiltered
 * SUM/MIN/MAX or auxiliary-group call.
 */
@Override
public boolean matches(RelOptRuleCall call) {
	final Aggregate aggregate = call.rel(0);
	final RelNode input = call.rel(1);

	final boolean simpleGroupBy = aggregate.getGroupCount() != 0
			&& !aggregate.indicator
			&& aggregate.getGroupType() == Aggregate.Group.SIMPLE;
	if (!simpleGroupBy) {
		return false;
	}

	for (AggregateCall aggCall : aggregate.getAggCallList()) {
		final SqlKind kind = aggCall.getAggregation().getKind();
		// TODO supports more AggregateCalls
		final boolean supportedFunction = kind == SqlKind.SUM
				|| kind == SqlKind.MIN
				|| kind == SqlKind.MAX
				|| aggCall.getAggregation() instanceof SqlAuxiliaryGroupAggFunction;
		if (!supportedFunction) {
			return false;
		}
		// Each call must take exactly one argument and carry no FILTER.
		if (aggCall.filterArg >= 0 || aggCall.getArgList().size() != 1) {
			return false;
		}
	}

	return SqlFunctions.isTrue(
			call.getMetadataQuery().areColumnsUnique(input, aggregate.getGroupSet()));
}
 
Example 5
Source File: FlinkAggregateRemoveRule.java    From flink with Apache License 2.0 6 votes vote down vote up
/**
 * Returns true when the aggregate can be removed. It must be a SIMPLE
 * GROUP BY with at least one key and no indicator columns. Every call must
 * be a single-argument, unfiltered SUM/MIN/MAX or auxiliary-group call. The
 * input must already be unique on the group keys.
 */
@Override
public boolean matches(RelOptRuleCall call) {
	final Aggregate aggregate = call.rel(0);
	final RelNode input = call.rel(1);
	if (aggregate.getGroupCount() == 0) {
		return false;
	}
	if (aggregate.indicator || aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
		return false;
	}
	for (AggregateCall aggCall : aggregate.getAggCallList()) {
		final SqlKind kind = aggCall.getAggregation().getKind();
		// TODO supports more AggregateCalls
		final boolean plainKind =
				kind == SqlKind.SUM || kind == SqlKind.MIN || kind == SqlKind.MAX;
		if (!plainKind
				&& !(aggCall.getAggregation() instanceof SqlAuxiliaryGroupAggFunction)) {
			return false;
		}
		if (aggCall.filterArg >= 0 || aggCall.getArgList().size() != 1) {
			return false;
		}
	}
	final RelMetadataQuery mq = call.getMetadataQuery();
	return SqlFunctions.isTrue(mq.areColumnsUnique(input, aggregate.getGroupSet()));
}
 
Example 6
Source File: PruneScanRule.java    From Bats with Apache License 2.0 6 votes vote down vote up
/**
 * Matches an aggregate over a scan that qualifies for file pruning. The
 * scan and the aggregate must have the same number of fields, and every
 * scanned field must match {@code dirPattern} (partition columns such as
 * dir0, dir1, ..., dirN). The scan must also be distinct or the aggregate
 * must group on at least one key.
 */
@Override
public boolean matches(RelOptRuleCall call) {
  final Aggregate aggregate = call.rel(0);
  final TableScan scan = call.rel(1);

  if (!isQualifiedFilePruning(scan)) {
    return false;
  }
  final int scanFieldCount = scan.getRowType().getFieldCount();
  if (scanFieldCount != aggregate.getRowType().getFieldCount()) {
    return false;
  }

  // Check if select contains partition columns (dir0, dir1, dir2,..., dirN) only
  final boolean onlyDirColumns = scan.getRowType().getFieldNames().stream()
      .allMatch(fieldName -> dirPattern.matcher(fieldName).matches());
  if (!onlyDirColumns) {
    return false;
  }

  return scan.isDistinct() || aggregate.getGroupCount() > 0;
}
 
Example 7
Source File: AggregateReduceFunctionsRule.java    From Bats with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a single-argument call to {@code aggFunction} over input ordinal
 * {@code argOrdinal}. The call copies DISTINCT, approximate and collation
 * from {@code oldCall}. Its return type is inferred from a binding with one
 * operand of type {@code operandType}, under the old aggregate's group
 * count.
 *
 * @param filter filter argument ordinal; a negative value means no filter
 */
private AggregateCall createAggregateCallWithBinding(
    RelDataTypeFactory typeFactory,
    SqlAggFunction aggFunction,
    RelDataType operandType,
    Aggregate oldAggRel,
    AggregateCall oldCall,
    int argOrdinal,
    int filter) {
  final RelDataType returnType = aggFunction.inferReturnType(
      new Aggregate.AggCallBinding(typeFactory, aggFunction,
          ImmutableList.of(operandType), oldAggRel.getGroupCount(),
          filter >= 0));
  return AggregateCall.create(aggFunction, oldCall.isDistinct(),
      oldCall.isApproximate(), ImmutableIntList.of(argOrdinal), filter,
      oldCall.collation, returnType, null);
}
 
Example 8
Source File: AbstractMaterializedViewRule.java    From Bats with Apache License 2.0 5 votes vote down vote up
/**
 * Returns an immutable list of input references to {@code node}'s columns.
 * For an Aggregate, only positions {@code 0 .. getGroupCount() - 1} (the
 * grouping columns) are referenced. For any other node, every field of its
 * row type is referenced.
 */
private static List<RexNode> extractReferences(RexBuilder rexBuilder, RelNode node) {
    final int referenceCount = node instanceof Aggregate
        ? ((Aggregate) node).getGroupCount()
        : node.getRowType().getFieldCount();
    final ImmutableList.Builder<RexNode> references = ImmutableList.builder();
    for (int ordinal = 0; ordinal < referenceCount; ordinal++) {
        references.add(rexBuilder.makeInputRef(node, ordinal));
    }
    return references.build();
}
 
Example 9
Source File: AggregateRemoveRule.java    From calcite with Apache License 2.0 5 votes vote down vote up
/**
 * Returns whether the aggregate is a SIMPLE GROUP BY with at least one key
 * whose calls are all unfiltered and unwrap to a
 * {@code SqlSplittableAggFunction}.
 */
private static boolean isAggregateSupported(Aggregate aggregate) {
  final boolean simpleWithKeys =
      aggregate.getGroupType() == Aggregate.Group.SIMPLE
          && aggregate.getGroupCount() != 0;
  if (!simpleWithKeys) {
    return false;
  }
  // If any aggregate functions do not support splitting, bail out.
  return aggregate.getAggCallList().stream()
      .allMatch(aggregateCall -> aggregateCall.filterArg < 0
          && aggregateCall.getAggregation()
              .unwrap(SqlSplittableAggFunction.class) != null);
}
 
Example 10
Source File: MaterializedViewRule.java    From calcite with Apache License 2.0 5 votes vote down vote up
/**
 * Returns an immutable list of input references to {@code node}'s columns.
 * For an Aggregate, these are the grouping columns
 * (positions {@code 0 .. getGroupCount() - 1}). Otherwise they are all
 * fields of the node's row type.
 */
protected List<RexNode> extractReferences(RexBuilder rexBuilder, RelNode node) {
  final int count;
  if (node instanceof Aggregate) {
    count = ((Aggregate) node).getGroupCount();
  } else {
    count = node.getRowType().getFieldCount();
  }
  final ImmutableList.Builder<RexNode> refs = ImmutableList.builder();
  int i = 0;
  while (i < count) {
    refs.add(rexBuilder.makeInputRef(node, i));
    i++;
  }
  return refs.build();
}
 
Example 11
Source File: AggregateReduceFunctionsRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces {@code SUM(x)} to {@code SUM0(x)}. When the original call's type
 * is nullable, the result is wrapped as
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
 *
 * @param oldAggRel      the aggregate being rewritten
 * @param oldCall        the SUM call; its argument ordinals index the
 *                       aggregate's input
 * @param newCalls       aggregate calls of the new aggregate (passed to
 *                       {@code addAggCall})
 * @param aggCallMapping mapping from call to its reference in the new
 *                       aggregate's output (passed to {@code addAggCall})
 * @return an expression equivalent to {@code oldCall}
 */
private RexNode reduceSum(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final int nGroups = oldAggRel.getGroupCount();
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  int arg = oldCall.getArgList().get(0);
  RelDataType argType =
      getFieldType(
          oldAggRel.getInput(),
          arg);

  final AggregateCall sumZeroCall =
      AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(),
          oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
          oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(),
          null, oldCall.name);
  // The argument ordinals refer to the aggregate's input, so the input (not
  // the aggregate itself) is passed for type derivation, just as for
  // sumZeroCall. Passing oldAggRel resolved the ordinals against the
  // aggregate's output row type, which can pick the wrong column or run
  // past its end.
  final AggregateCall countCall =
      AggregateCall.create(SqlStdOperatorTable.COUNT,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          oldCall.filterArg,
          oldCall.collation,
          oldAggRel.getGroupCount(),
          oldAggRel.getInput(),
          null,
          null);

  // NOTE:  these references are with respect to the output
  // of newAggRel
  RexNode sumZeroRef =
      rexBuilder.addAggCall(sumZeroCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));
  if (!oldCall.getType().isNullable()) {
    // If SUM(x) is not nullable, the validator must have determined that
    // nulls are impossible (because the group is never empty and x is never
    // null). Therefore we translate to SUM0(x).
    return sumZeroRef;
  }
  RexNode countRef =
      rexBuilder.addAggCall(countCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));
  return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
      rexBuilder.makeCast(sumZeroRef.getType(), rexBuilder.constantNull()),
      sumZeroRef);
}
 
Example 12
Source File: AggregateProjectMergeRule.java    From calcite with Apache License 2.0 4 votes vote down vote up
/**
 * Merges a {@code Project} into the {@code Aggregate} that sits on top of it.
 *
 * <p>Every field the aggregate uses must be a plain {@code RexInputRef} in
 * the project; if any is not, this returns {@code null}. Otherwise a new
 * aggregate is built directly on the project's input. Its group set,
 * grouping sets and aggregate calls are remapped through the project. If the
 * remapped group keys are no longer in the original order, or contain
 * duplicates, a project is added on top to restore the original field
 * order.
 *
 * @param call      the rule call; supplies the {@code RelBuilder}
 * @param aggregate the aggregate on top
 * @param project   the project below the aggregate
 * @return the rewritten expression, or {@code null} if a used project
 *         expression is not a simple field reference
 */
public static RelNode apply(RelOptRuleCall call, Aggregate aggregate,
    Project project) {
  // Find all fields which we need to be straightforward field projections.
  final Set<Integer> interestingFields = RelOptUtil.getAllFields(aggregate);

  // Build the map from old to new; abort if any entry is not a
  // straightforward field projection.
  final Map<Integer, Integer> map = new HashMap<>();
  for (int source : interestingFields) {
    final RexNode rex = project.getProjects().get(source);
    if (!(rex instanceof RexInputRef)) {
      return null;
    }
    map.put(source, ((RexInputRef) rex).getIndex());
  }

  // Translate the group set into the project input's column space. For
  // non-SIMPLE groupings, do the same for every grouping set and keep them
  // sorted by ImmutableBitSet.ORDERING. SIMPLE groupings pass null.
  final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
  ImmutableList<ImmutableBitSet> newGroupingSets = null;
  if (aggregate.getGroupType() != Group.SIMPLE) {
    newGroupingSets =
        ImmutableBitSet.ORDERING.immutableSortedCopy(
            ImmutableBitSet.permute(aggregate.getGroupSets(), map));
  }

  // Rewrite each aggregate call's field references through the same map.
  // The mapping goes from the project's output (sourceCount fields) to the
  // project's input (targetCount fields).
  final ImmutableList.Builder<AggregateCall> aggCalls =
      ImmutableList.builder();
  final int sourceCount = aggregate.getInput().getRowType().getFieldCount();
  final int targetCount = project.getInput().getRowType().getFieldCount();
  final Mappings.TargetMapping targetMapping =
      Mappings.target(map, sourceCount, targetCount);
  for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
    aggCalls.add(aggregateCall.transform(targetMapping));
  }

  final Aggregate newAggregate =
      aggregate.copy(aggregate.getTraitSet(), project.getInput(),
          newGroupSet, newGroupingSets, aggCalls.build());

  // Add a project if the group set is not in the same order or
  // contains duplicates.
  final RelBuilder relBuilder = call.builder();
  relBuilder.push(newAggregate);
  // newKeys.get(i) is the new input ordinal of the i-th original group key.
  // Lists.transform returns a lazy view.
  final List<Integer> newKeys =
      Lists.transform(aggregate.getGroupSet().asList(), map::get);
  if (!newKeys.equals(newGroupSet.asList())) {
    // Output layout: each original group key's position within the new
    // group set, followed by all remaining (aggregate call) columns
    // unchanged.
    final List<Integer> posList = new ArrayList<>();
    for (int newKey : newKeys) {
      posList.add(newGroupSet.indexOf(newKey));
    }
    for (int i = newAggregate.getGroupCount();
         i < newAggregate.getRowType().getFieldCount(); i++) {
      posList.add(i);
    }
    relBuilder.project(relBuilder.fields(posList));
  }

  return relBuilder.build();
}
 
Example 13
Source File: AggregateReduceFunctionsRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces {@code AVG(x)} to {@code SUM(x) / COUNT(x)}.
 *
 * <p>The SUM and COUNT calls keep the original call's DISTINCT, approximate,
 * filter and collation settings. Before dividing, the numerator is coerced
 * to the AVG type, using the SUM's nullability. The quotient is then cast
 * back to the original call's type.
 *
 * @param oldAggRel      the aggregate being rewritten
 * @param oldCall        the AVG call; its first argument indexes the
 *                       aggregate's input
 * @param newCalls       aggregate calls of the new aggregate (passed to
 *                       {@code addAggCall})
 * @param aggCallMapping mapping from call to its reference in the new
 *                       aggregate's output (passed to {@code addAggCall})
 * @param inputExprs     input expressions; not used by this reduction
 * @return an expression equivalent to {@code oldCall}
 */
private RexNode reduceAvg(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping,
    List<RexNode> inputExprs) {
  final int groupCount = oldAggRel.getGroupCount();
  final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
  final ImmutableList<RelDataType> argTypes = ImmutableList.of(argType);

  final AggregateCall sumCall =
      AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(),
          oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
          oldCall.collation, groupCount, oldAggRel.getInput(), null, null);
  final AggregateCall countCall =
      AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(),
          oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
          oldCall.collation, groupCount, oldAggRel.getInput(), null, null);

  // Both references are with respect to the output of the new aggregate.
  final RexNode sumRef = builder.addAggCall(sumCall, groupCount,
      oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
  final RexNode countRef = builder.addAggCall(countCall, groupCount,
      oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

  final RelDataType avgType = oldAggRel.getCluster().getTypeFactory()
      .createTypeWithNullability(oldCall.getType(),
          sumRef.getType().isNullable());
  final RexNode numerator = builder.ensureType(avgType, sumRef, true);
  final RexNode quotient =
      builder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, countRef);
  return builder.makeCast(oldCall.getType(), quotient);
}
 
Example 14
Source File: DrillReduceAggregatesRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
 * the aggregate's call list into expressions over simpler aggregate calls
 * (via {@code reduceAgg}). The rewritten plan is registered with
 * {@code ruleCall.transformTo}.
 *
 * <p>It handles newly generated common subexpressions since this was done
 * at the sql2rel stage.
 *
 * <p>The result is a new aggregate, placed on an extra project if some
 * reduced call needed new input expressions. A project on top of it keeps
 * the original aggregate's field names.
 *
 * @param ruleCall  the rule call that receives the rewritten plan
 * @param oldAggRel the aggregate whose calls are reduced
 */
private void reduceAggs(
    RelOptRuleCall ruleCall,
    Aggregate oldAggRel) {
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
  final int nGroups = oldAggRel.getGroupCount();

  // Filled by reduceAgg with the calls (and their output references) of
  // the new aggregate.
  List<AggregateCall> newCalls = new ArrayList<>();
  Map<AggregateCall, RexNode> aggCallMapping =
      new HashMap<>();

  List<RexNode> projList = new ArrayList<>();

  // pass through group key
  for (int i = 0; i < nGroups; ++i) {
    projList.add(
        rexBuilder.makeInputRef(
            getFieldType(oldAggRel, i),
            i));
  }

  // List of input expressions. If a particular aggregate needs more, it
  // will add an expression to the end, and we will create an extra
  // project.
  RelNode input = oldAggRel.getInput();
  List<RexNode> inputExprs = new ArrayList<>();
  for (RelDataTypeField field : input.getRowType().getFieldList()) {
    inputExprs.add(
        rexBuilder.makeInputRef(
            field.getType(), inputExprs.size()));
  }

  // create new agg function calls and rest of project list together
  for (AggregateCall oldCall : oldCalls) {
    projList.add(
        reduceAgg(
            oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
  }

  // Expressions appended beyond the input's own fields need a project to
  // compute them. The appended fields get null (unspecified) names.
  final int extraArgCount =
      inputExprs.size() - input.getRowType().getFieldCount();
  if (extraArgCount > 0) {
    input =
        relBuilderFactory
            .create(input.getCluster(), null)
            .push(input)
            .projectNamed(
                inputExprs,
                CompositeList.of(
                input.getRowType().getFieldNames(),
                Collections.nCopies(
                    extraArgCount,
                    null)),
                true)
            .build();
  }
  Aggregate newAggRel =
      newAggregateRel(
          oldAggRel, input, newCalls);

  // Project the group keys and reduced expressions under the original
  // aggregate's field names.
  RelNode projectRel =
      relBuilderFactory
          .create(newAggRel.getCluster(), null)
          .push(newAggRel)
          .projectNamed(projList, oldAggRel.getRowType().getFieldNames(), true)
          .build();

  ruleCall.transformTo(projectRel);
}
 
Example 15
Source File: DrillReduceAggregatesRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces {@code AVG(x)} to a Drill-specific {@code SUM(x) / COUNT(x)}.
 *
 * <p>The numerator is
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE sum END}, wrapped in
 * {@code CastHighOp}, where the sum uses {@code SqlSumEmptyIsZeroAggFunction}.
 * When type inference is enabled, the division is a Drill {@code divide}
 * operator typed as the original call's type. Otherwise it is the standard
 * DIVIDE, cast to {@code ANY}.
 *
 * <p>NOTE(review): both new calls are created with filter argument
 * {@code -1} and without the original collation. A FILTER on
 * {@code oldCall} is therefore not carried over; confirm that filtered AVG
 * calls never reach this method.
 *
 * @param oldAggRel      the aggregate being rewritten
 * @param oldCall        the AVG call; its first argument indexes the
 *                       aggregate's input
 * @param newCalls       aggregate calls of the new aggregate (passed to
 *                       {@code addAggCall})
 * @param aggCallMapping mapping from call to its reference in the new
 *                       aggregate's output (passed to {@code addAggCall})
 * @return an expression equivalent to {@code oldCall}
 */
private RexNode reduceAvg(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
  final int nGroups = oldAggRel.getGroupCount();
  RelDataTypeFactory typeFactory =
      oldAggRel.getCluster().getTypeFactory();
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  int iAvgInput = oldCall.getArgList().get(0);
  RelDataType avgInputType =
      getFieldType(
          oldAggRel.getInput(),
          iAvgInput);
  RelDataType sumType =
      TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(),
          ImmutableList.of())
        .inferReturnType(oldCall.createBinding(oldAggRel));
  // Force a nullable sum when there are no group keys. Presumably this is
  // because a global aggregate still emits a row for empty input; confirm.
  sumType =
      typeFactory.createTypeWithNullability(
          sumType,
          sumType.isNullable() || nGroups == 0);
  SqlAggFunction sumAgg =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(),
      oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
      oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

  RexNode tmpsumRef =
      rexBuilder.addAggCall(
          sumCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));

  RexNode tmpcountRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));

  // CASE WHEN COUNT(x) = 0 THEN NULL ELSE sum END
  RexNode n = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          tmpcountRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
          rexBuilder.constantNull(),
          tmpsumRef);

  // NOTE:  these references are with respect to the output
  // of newAggRel
  // (An earlier version applied CastHighOp directly to the SUM reference.
  // It now wraps the null-guarded CASE expression "n" instead.)
  RexNode numeratorRef = rexBuilder.makeCall(CastHighOp,  n);

  // Adds countCall again with the same arguments as tmpcountRef. Presumably
  // addAggCall returns the existing reference via aggCallMapping instead of
  // adding a duplicate call; confirm.
  RexNode denominatorRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));
  if (isInferenceEnabled) {
    return rexBuilder.makeCall(
        new DrillSqlOperator(
            "divide",
            2,
            true,
            oldCall.getType(), false),
        numeratorRef,
        denominatorRef);
  } else {
    final RexNode divideRef =
        rexBuilder.makeCall(
            SqlStdOperatorTable.DIVIDE,
            numeratorRef,
            denominatorRef);
    return rexBuilder.makeCast(
        typeFactory.createSqlType(SqlTypeName.ANY), divideRef);
  }
}
 
Example 16
Source File: DrillReduceAggregatesRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces {@code SUM(x)} to Drill's empty-is-zero sum. Unless the original
 * call's type is NOT NULL, the result is wrapped as
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE sum END}.
 *
 * <p>When type inference is enabled, the sum keeps the original call's
 * type; otherwise its nullability follows the argument's. Both new calls
 * are created with filter argument {@code -1} and no collation.
 * NOTE(review): confirm that filtered SUM calls never reach this method.
 *
 * @param oldAggRel      the aggregate being rewritten
 * @param oldCall        the SUM call; its first argument indexes the
 *                       aggregate's input
 * @param newCalls       aggregate calls of the new aggregate (passed to
 *                       {@code addAggCall})
 * @param aggCallMapping mapping from call to its reference in the new
 *                       aggregate's output (passed to {@code addAggCall})
 * @return an expression equivalent to {@code oldCall}
 */
private RexNode reduceSum(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
  final ImmutableList<RelDataType> argTypes = ImmutableList.of(argType);

  final RelDataType sumType = inferTypes
      ? oldCall.getType()
      : typeFactory.createTypeWithNullability(oldCall.getType(), argType.isNullable());
  final SqlAggFunction sumZeroAgg = new DrillCalciteSqlAggFunctionWrapper(
      new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumZeroCall = AggregateCall.create(sumZeroAgg,
      oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
      -1, sumType, null);

  final SqlCountAggFunction countAgg =
      (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  final AggregateCall countCall = AggregateCall.create(countAgg,
      oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(),
      -1, countType, null);

  // References below are with respect to the output of the new aggregate.
  final RexNode sumZeroRef = builder.addAggCall(sumZeroCall, groupCount,
      oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
  if (oldCall.getType().isNullable()) {
    final RexNode countRef = builder.addAggCall(countCall, groupCount,
        oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
    final RexNode countIsZero = builder.makeCall(SqlStdOperatorTable.EQUALS,
        countRef, builder.makeExactLiteral(BigDecimal.ZERO));
    return builder.makeCall(SqlStdOperatorTable.CASE,
        countIsZero, builder.constantNull(), sumZeroRef);
  }
  // A NOT NULL SUM means the validator ruled out nulls (non-empty group,
  // non-null x), so SUM0 alone is exact.
  return sumZeroRef;
}
 
Example 17
Source File: GroupByPartitionRule.java    From Mycat2 with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Returns whether {@code r} groups on exactly one key and has exactly one
 * aggregate call whose kind is a key of {@code SUPPORTED_AGGREGATES}.
 */
private static boolean test(Aggregate r) {
    if (r.getGroupCount() != 1 || r.getAggCallList().size() != 1) {
        return false;
    }
    return SUPPORTED_AGGREGATES.containsKey(
            r.getAggCallList().get(0).getAggregation().kind);
}
 
Example 18
Source File: AggregateReduceFunctionsRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
 * the aggregate's call list into expressions over simpler aggregate calls
 * (via {@code reduceAgg}). The rewritten plan is registered with
 * {@code ruleCall.transformTo}.
 *
 * <p>It handles newly generated common subexpressions since this was done
 * at the sql2rel stage.
 *
 * @param ruleCall  the rule call; supplies the {@code RelBuilder} and
 *                  receives the rewritten plan
 * @param oldAggRel the aggregate whose calls are reduced
 */
private void reduceAggs(
    RelOptRuleCall ruleCall,
    Aggregate oldAggRel) {
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
  final int groupCount = oldAggRel.getGroupCount();
  final int indicatorCount = oldAggRel.getIndicatorCount();

  // Filled by reduceAgg with the calls (and their output references) of
  // the new aggregate.
  final List<AggregateCall> newCalls = new ArrayList<>();
  final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();

  final List<RexNode> projList = new ArrayList<>();

  // pass through group key (+ indicators if present)
  for (int i = 0; i < groupCount + indicatorCount; ++i) {
    projList.add(
        rexBuilder.makeInputRef(
            getFieldType(oldAggRel, i),
            i));
  }

  // List of input expressions. If a particular aggregate needs more, it
  // will add an expression to the end, and we will create an extra
  // project.
  final RelBuilder relBuilder = ruleCall.builder();
  relBuilder.push(oldAggRel.getInput());
  final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());

  // create new agg function calls and rest of project list together
  for (AggregateCall oldCall : oldCalls) {
    projList.add(
        reduceAgg(
            oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
  }

  // Expressions appended beyond the input's own fields need a project to
  // compute them. The appended fields get null (unspecified) names.
  final int extraArgCount =
      inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
  if (extraArgCount > 0) {
    relBuilder.project(inputExprs,
        CompositeList.of(
            relBuilder.peek().getRowType().getFieldNames(),
            Collections.nCopies(extraArgCount, null)));
  }
  // Build the new aggregate, then the projection that restores the
  // original row type, on top of whatever is on the builder's stack.
  newAggregateRel(relBuilder, oldAggRel, newCalls);
  newCalcRel(relBuilder, oldAggRel.getRowType(), projList);
  ruleCall.transformTo(relBuilder.build());
}
 
Example 19
Source File: AggregateReduceFunctionsRule.java    From calcite with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces calls to functions AVG, SUM, STDDEV_POP, STDDEV_SAMP, VAR_POP,
 * VAR_SAMP, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY if the function is
 * present in {@link AggregateReduceFunctionsRule#functionsToReduce}
 *
 * <p>It handles newly generated common subexpressions since this was done
 * at the sql2rel stage.
 *
 * <p>The rewritten plan is a new aggregate (on an extra project if a
 * reduction needed new input expressions) followed by a projection that
 * produces {@code oldAggRel}'s row type. It is registered with
 * {@code ruleCall.transformTo}.
 *
 * @param ruleCall  the rule call; supplies the {@code RelBuilder} and
 *                  receives the rewritten plan
 * @param oldAggRel the aggregate whose calls are reduced
 */
private void reduceAggs(
    RelOptRuleCall ruleCall,
    Aggregate oldAggRel) {
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
  final int groupCount = oldAggRel.getGroupCount();

  // Filled by reduceAgg with the calls (and their output references) of
  // the new aggregate.
  final List<AggregateCall> newCalls = new ArrayList<>();
  final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();

  final List<RexNode> projList = new ArrayList<>();

  // pass through group key
  for (int i = 0; i < groupCount; ++i) {
    projList.add(
        rexBuilder.makeInputRef(
            getFieldType(oldAggRel, i),
            i));
  }

  // List of input expressions. If a particular aggregate needs more, it
  // will add an expression to the end, and we will create an extra
  // project.
  final RelBuilder relBuilder = ruleCall.builder();
  relBuilder.push(oldAggRel.getInput());
  final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());

  // create new agg function calls and rest of project list together
  for (AggregateCall oldCall : oldCalls) {
    projList.add(
        reduceAgg(
            oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
  }

  // Expressions appended beyond the input's own fields need a project to
  // compute them. The appended fields get null (unspecified) names.
  final int extraArgCount =
      inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
  if (extraArgCount > 0) {
    relBuilder.project(inputExprs,
        CompositeList.of(
            relBuilder.peek().getRowType().getFieldNames(),
            Collections.nCopies(extraArgCount, null)));
  }
  // Build the new aggregate, then the projection that restores the
  // original row type, on top of whatever is on the builder's stack.
  newAggregateRel(relBuilder, oldAggRel, newCalls);
  newCalcRel(relBuilder, oldAggRel.getRowType(), projList);
  ruleCall.transformTo(relBuilder.build());
}
 
Example 20
Source File: AggregateReduceFunctionsRule.java    From calcite with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces {@code SUM(x)} to {@code SUM0(x)}. When the original call's type
 * is nullable, the result is wrapped as
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
 *
 * @param oldAggRel      the aggregate being rewritten
 * @param oldCall        the SUM call; its argument ordinals index the
 *                       aggregate's input
 * @param newCalls       aggregate calls of the new aggregate (passed to
 *                       {@code addAggCall})
 * @param aggCallMapping mapping from call to its reference in the new
 *                       aggregate's output (passed to {@code addAggCall})
 * @return an expression equivalent to {@code oldCall}
 */
private RexNode reduceSum(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final int nGroups = oldAggRel.getGroupCount();
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  int arg = oldCall.getArgList().get(0);
  RelDataType argType =
      getFieldType(
          oldAggRel.getInput(),
          arg);

  final AggregateCall sumZeroCall =
      AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(),
          oldCall.isApproximate(), oldCall.ignoreNulls(),
          oldCall.getArgList(), oldCall.filterArg,
          oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(),
          null, oldCall.name);
  // The argument ordinals refer to the aggregate's input, so the input (not
  // the aggregate itself) is passed for type derivation, just as for
  // sumZeroCall. Passing oldAggRel resolved the ordinals against the
  // aggregate's output row type, which can pick the wrong column or run
  // past its end.
  final AggregateCall countCall =
      AggregateCall.create(SqlStdOperatorTable.COUNT,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.ignoreNulls(),
          oldCall.getArgList(),
          oldCall.filterArg,
          oldCall.collation,
          oldAggRel.getGroupCount(),
          oldAggRel.getInput(),
          null,
          null);

  // NOTE:  these references are with respect to the output
  // of newAggRel
  RexNode sumZeroRef =
      rexBuilder.addAggCall(sumZeroCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));
  if (!oldCall.getType().isNullable()) {
    // If SUM(x) is not nullable, the validator must have determined that
    // nulls are impossible (because the group is never empty and x is never
    // null). Therefore we translate to SUM0(x).
    return sumZeroRef;
  }
  RexNode countRef =
      rexBuilder.addAggCall(countCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));
  return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
      rexBuilder.makeNullLiteral(sumZeroRef.getType()),
      sumZeroRef);
}