Java Code Examples for org.apache.calcite.sql.fun.SqlStdOperatorTable#COUNT

The following examples show how to use org.apache.calcite.sql.fun.SqlStdOperatorTable#COUNT. You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You may also check out related API usage in the sidebar.
Example 1
Source File: Lattice.java    From Bats with Apache License 2.0 5 votes vote down vote up
/**
 * Maps a lattice aggregate name onto the matching standard operator.
 *
 * <p>Only "count" and "sum" are recognized; the comparison ignores case.
 *
 * @param aggName name of the aggregate function
 * @return {@code SqlStdOperatorTable.COUNT} or {@code SqlStdOperatorTable.SUM}
 * @throws RuntimeException if the name is neither "count" nor "sum"
 */
private SqlAggFunction resolveAgg(String aggName) {
  if (aggName.equalsIgnoreCase("count")) {
    return SqlStdOperatorTable.COUNT;
  }
  if (aggName.equalsIgnoreCase("sum")) {
    return SqlStdOperatorTable.SUM;
  }
  final String message = "Unknown lattice aggregate function " + aggName;
  throw new RuntimeException(message);
}
 
Example 2
Source File: AbstractMaterializedViewRule.java    From Bats with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the aggregate function used to roll up partial results of
 * {@code aggregation}.
 *
 * <p>SUM, MIN, MAX, SUM0 and ANY_VALUE roll up with themselves; COUNT rolls
 * up with SUM0.
 *
 * @param aggregation aggregate function whose results are rolled up
 * @return the rollup function, or {@code null} when none of the cases
 *     above applies
 */
protected SqlAggFunction getRollup(SqlAggFunction aggregation) {
    final boolean rollsUpWithItself =
            aggregation == SqlStdOperatorTable.SUM
                    || aggregation == SqlStdOperatorTable.MIN
                    || aggregation == SqlStdOperatorTable.MAX
                    || aggregation == SqlStdOperatorTable.SUM0
                    || aggregation == SqlStdOperatorTable.ANY_VALUE;
    if (rollsUpWithItself) {
        return aggregation;
    }
    return aggregation == SqlStdOperatorTable.COUNT
            ? SqlStdOperatorTable.SUM0
            : null;
}
 
Example 3
Source File: AggregateUnionTransposeRule.java    From Bats with Apache License 2.0 5 votes vote down vote up
/**
 * Builds, for every call in {@code origCalls}, a replacement call that is
 * evaluated over {@code input}.
 *
 * <p>The i-th replacement takes the single argument
 * {@code groupCount + i}. COUNT is replaced by SUM0 with a null type; any
 * other aggregate keeps its function and type.
 *
 * @param input relational expression the new calls are created against
 * @param groupCount number of group columns, used as the argument offset
 * @param origCalls aggregate calls to transform
 * @return the transformed calls, or {@code null} if any call is distinct or
 *     its aggregation class is not a key of {@code SUPPORTED_AGGREGATES}
 */
private List<AggregateCall> transformAggCalls(RelNode input, int groupCount,
    List<AggregateCall> origCalls) {
  final List<AggregateCall> transformed = new ArrayList<>();
  for (Ord<AggregateCall> indexed : Ord.zip(origCalls)) {
    final AggregateCall call = indexed.e;
    if (call.isDistinct()) {
      return null;
    }
    if (!SUPPORTED_AGGREGATES.containsKey(call.getAggregation().getClass())) {
      return null;
    }

    // COUNT never yields null, whereas the nullability of SUM may depend on
    // the number of GROUP BY columns. SUM0 is used because neither nullable
    // inputs nor an empty set can occur here; the type is passed as null.
    final boolean isCount =
        call.getAggregation() == SqlStdOperatorTable.COUNT;
    final SqlAggFunction aggFun =
        isCount ? SqlStdOperatorTable.SUM0 : call.getAggregation();
    final RelDataType aggType = isCount ? null : call.getType();

    transformed.add(
        AggregateCall.create(aggFun, call.isDistinct(), call.isApproximate(),
            ImmutableList.of(groupCount + indexed.i), -1, call.collation,
            groupCount, input, aggType, call.getName()));
  }
  return transformed;
}
 
Example 4
Source File: SubstitutionVisitor.java    From Bats with Apache License 2.0 5 votes vote down vote up
/**
 * Picks the aggregate function that combines partial results of
 * {@code aggregation}.
 *
 * @param aggregation aggregate function to roll up
 * @return {@code aggregation} itself for SUM, MIN, MAX, SUM0 and ANY_VALUE;
 *     SUM0 for COUNT; {@code null} otherwise
 */
public static SqlAggFunction getRollup(SqlAggFunction aggregation) {
    final boolean selfRollup =
            aggregation == SqlStdOperatorTable.SUM
                    || aggregation == SqlStdOperatorTable.MIN
                    || aggregation == SqlStdOperatorTable.MAX
                    || aggregation == SqlStdOperatorTable.SUM0
                    || aggregation == SqlStdOperatorTable.ANY_VALUE;
    if (selfRollup) {
        return aggregation;
    }
    if (aggregation == SqlStdOperatorTable.COUNT) {
        return SqlStdOperatorTable.SUM0;
    }
    return null;
}
 
Example 5
Source File: Lattice.java    From calcite with Apache License 2.0 5 votes vote down vote up
/**
 * Resolves the name of a lattice aggregate function to its operator.
 *
 * @param aggName aggregate name, matched case-insensitively against
 *     "count" and "sum"
 * @return {@code SqlStdOperatorTable.COUNT} for "count",
 *     {@code SqlStdOperatorTable.SUM} for "sum"
 * @throws RuntimeException for any other name
 */
private SqlAggFunction resolveAgg(String aggName) {
  final boolean isCount = aggName.equalsIgnoreCase("count");
  if (!isCount && !aggName.equalsIgnoreCase("sum")) {
    throw new RuntimeException("Unknown lattice aggregate function "
        + aggName);
  }
  return isCount ? SqlStdOperatorTable.COUNT : SqlStdOperatorTable.SUM;
}
 
Example 6
Source File: MaterializedViewAggregateRule.java    From calcite with Apache License 2.0 5 votes vote down vote up
/**
 * Determines the rollup aggregation function for {@code aggregation}.
 *
 * <p>SUM, MIN, MAX, SUM0 and ANY_VALUE are their own rollup; COUNT is
 * rolled up with SUM0.
 *
 * @param aggregation aggregate function to roll up
 * @return the rollup function, or {@code null} if there is none
 */
protected SqlAggFunction getRollup(SqlAggFunction aggregation) {
  SqlAggFunction rollup = null;
  if (aggregation == SqlStdOperatorTable.SUM
      || aggregation == SqlStdOperatorTable.MIN
      || aggregation == SqlStdOperatorTable.MAX
      || aggregation == SqlStdOperatorTable.SUM0
      || aggregation == SqlStdOperatorTable.ANY_VALUE) {
    rollup = aggregation;
  } else if (aggregation == SqlStdOperatorTable.COUNT) {
    rollup = SqlStdOperatorTable.SUM0;
  }
  return rollup;
}
 
Example 7
Source File: AggregateUnionTransposeRule.java    From calcite with Apache License 2.0 5 votes vote down vote up
/**
 * Creates one replacement aggregate call over {@code input} for each call
 * in {@code origCalls}.
 *
 * <p>The replacement at position i reads the single argument
 * {@code groupCount + i} and keeps the original call's distinct,
 * approximate and ignore-nulls flags, collation and name. COUNT becomes
 * SUM0 with a null type; every other aggregate keeps its function and type.
 *
 * @param input relational expression the new calls are created against
 * @param groupCount number of group columns, used as the argument offset
 * @param origCalls aggregate calls to transform
 * @return the new calls, or {@code null} as soon as a call is distinct or
 *     its aggregation class is not a key of {@code SUPPORTED_AGGREGATES}
 */
private List<AggregateCall> transformAggCalls(RelNode input, int groupCount,
    List<AggregateCall> origCalls) {
  final List<AggregateCall> rewritten = new ArrayList<>();
  for (Ord<AggregateCall> entry : Ord.zip(origCalls)) {
    final AggregateCall source = entry.e;
    if (source.isDistinct()) {
      return null;
    }
    if (!SUPPORTED_AGGREGATES.containsKey(
        source.getAggregation().getClass())) {
      return null;
    }

    final SqlAggFunction aggFun;
    final RelDataType aggType;
    if (source.getAggregation() != SqlStdOperatorTable.COUNT) {
      aggFun = source.getAggregation();
      aggType = source.getType();
    } else {
      // COUNT is never null, but whether SUM is nullable may depend on how
      // many columns are in GROUP BY. SUM0 is chosen because nullable
      // inputs and empty sets cannot occur here.
      aggFun = SqlStdOperatorTable.SUM0;
      aggType = null;
    }

    rewritten.add(
        AggregateCall.create(aggFun, source.isDistinct(),
            source.isApproximate(), source.ignoreNulls(),
            ImmutableList.of(groupCount + entry.i), -1, source.collation,
            groupCount, input, aggType, source.getName()));
  }
  return rewritten;
}
 
Example 8
Source File: SubstitutionVisitor.java    From calcite with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the aggregate function that rolls up results of
 * {@code aggregation}.
 *
 * @param aggregation aggregate function to roll up
 * @return {@code aggregation} for SUM, MIN, MAX, SUM0 and ANY_VALUE;
 *     SUM0 for COUNT; {@code null} for anything else
 */
public static SqlAggFunction getRollup(SqlAggFunction aggregation) {
  final boolean isOwnRollup =
      aggregation == SqlStdOperatorTable.SUM
          || aggregation == SqlStdOperatorTable.MIN
          || aggregation == SqlStdOperatorTable.MAX
          || aggregation == SqlStdOperatorTable.SUM0
          || aggregation == SqlStdOperatorTable.ANY_VALUE;
  if (isOwnRollup) {
    return aggregation;
  }
  return aggregation == SqlStdOperatorTable.COUNT
      ? SqlStdOperatorTable.SUM0
      : null;
}
 
Example 9
Source File: DrillReduceAggregatesRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces an average-style aggregate call to a sum divided by a count.
 *
 * <p>Creates a sum call (a {@code SqlSumEmptyIsZeroAggFunction} wrapped in a
 * {@code DrillCalciteSqlAggFunctionWrapper}) and a COUNT call over the same
 * argument list as {@code oldCall}, registers them through
 * {@code rexBuilder.addAggCall}, and returns
 * {@code CastHighOp(CASE WHEN count = 0 THEN NULL ELSE sum END) / count}.
 *
 * @param oldAggRel aggregate that contains the call being reduced
 * @param oldCall call to reduce; its first argument selects the input field
 *     whose type is passed to {@code addAggCall}
 * @param newCalls list handed to {@code rexBuilder.addAggCall}
 *     (presumably the created calls are appended to it; confirm)
 * @param aggCallMapping map handed to {@code rexBuilder.addAggCall}
 *     (presumably records the reference created for each call; confirm)
 * @return a call to a Drill "divide" operator typed as {@code oldCall}'s
 *     type when type inference is enabled; otherwise a standard DIVIDE
 *     call cast to {@code SqlTypeName.ANY}
 */
private RexNode reduceAvg(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  // The planner context is cast to PlannerSettings to read the
  // type-inference switch.
  final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
  final int nGroups = oldAggRel.getGroupCount();
  RelDataTypeFactory typeFactory =
      oldAggRel.getCluster().getTypeFactory();
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  // Only the first argument of the old call is consulted for the input type.
  int iAvgInput = oldCall.getArgList().get(0);
  RelDataType avgInputType =
      getFieldType(
          oldAggRel.getInput(),
          iAvgInput);
  // Infer the sum's return type with Drill's SUM inference, then force it
  // nullable when there are no group columns.
  RelDataType sumType =
      TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(),
          ImmutableList.of())
        .inferReturnType(oldCall.createBinding(oldAggRel));
  sumType =
      typeFactory.createTypeWithNullability(
          sumType,
          sumType.isNullable() || nGroups == 0);
  SqlAggFunction sumAgg =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  // Both new calls reuse the old call's distinct/approximate flags and
  // argument list.
  AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(),
      oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
      oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

  // Register the sum call and obtain a reference to its result.
  RexNode tmpsumRef =
      rexBuilder.addAggCall(
          sumCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));

  // Register the count call and obtain a reference to its result.
  RexNode tmpcountRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));

  // CASE WHEN count = 0 THEN NULL ELSE sum END
  RexNode n = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          tmpcountRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
          rexBuilder.constantNull(),
          tmpsumRef);

  // NOTE:  these references are with respect to the output
  // of newAggRel
  // NOTE(review): the block below is commented-out code; it looks like an
  // earlier form of numeratorRef that cast the raw sum reference directly.
  /*
  RexNode numeratorRef =
      rexBuilder.makeCall(CastHighOp,
        rexBuilder.addAggCall(
            sumCall,
            nGroups,
            newCalls,
            aggCallMapping,
            ImmutableList.of(avgInputType))
      );
  */
  // The numerator is the null-guarded sum wrapped in CastHighOp.
  RexNode numeratorRef = rexBuilder.makeCall(CastHighOp,  n);

  // NOTE(review): countCall is passed to addAggCall a second time here;
  // presumably aggCallMapping makes this return the reference already
  // created for tmpcountRef instead of adding a duplicate call. Confirm.
  RexNode denominatorRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));
  if (isInferenceEnabled) {
    // With inference enabled the division is a Drill "divide" operator
    // constructed with the old call's type.
    return rexBuilder.makeCall(
        new DrillSqlOperator(
            "divide",
            2,
            true,
            oldCall.getType(), false),
        numeratorRef,
        denominatorRef);
  } else {
    // Otherwise use the standard DIVIDE operator and cast the result to ANY.
    final RexNode divideRef =
        rexBuilder.makeCall(
            SqlStdOperatorTable.DIVIDE,
            numeratorRef,
            denominatorRef);
    return rexBuilder.makeCast(
        typeFactory.createSqlType(SqlTypeName.ANY), divideRef);
  }
}
 
Example 10
Source File: DrillReduceAggregatesRule.java    From Bats with Apache License 2.0 4 votes vote down vote up
/**
 * Reduces a SUM aggregate call to a sum-empty-is-zero call, guarded by a
 * COUNT when the original result may be null.
 *
 * <p>If {@code oldCall}'s type is not nullable, the reference to the new
 * sum call is returned directly. Otherwise the result is
 * {@code CASE WHEN count = 0 THEN NULL ELSE sum0 END}.
 *
 * @param oldAggRel aggregate that contains the call being reduced
 * @param oldCall call to reduce; its first argument selects the input field
 *     whose type is passed to {@code addAggCall}
 * @param newCalls list handed to {@code rexBuilder.addAggCall}
 * @param aggCallMapping map handed to {@code rexBuilder.addAggCall}
 * @return expression that replaces the original call
 */
private RexNode reduceSum(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory =
      oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType =
      getFieldType(oldAggRel.getInput(), argOrdinal);

  // Without type inference, the sum's nullability follows the argument's.
  final RelDataType sumType = inferTypes
      ? oldCall.getType()
      : typeFactory.createTypeWithNullability(
          oldCall.getType(), argType.isNullable());
  final SqlAggFunction sumZeroAgg =
      new DrillCalciteSqlAggFunctionWrapper(
          new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumZeroCall =
      AggregateCall.create(sumZeroAgg, oldCall.isDistinct(),
          oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);

  final SqlCountAggFunction countAgg =
      (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  final AggregateCall countCall =
      AggregateCall.create(countAgg, oldCall.isDistinct(),
          oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

  // NOTE: the references below are relative to the output of newAggRel.
  final RexNode sumZeroRef =
      rexBuilder.addAggCall(sumZeroCall, groupCount, oldAggRel.indicator,
          newCalls, aggCallMapping, ImmutableList.of(argType));

  // A non-nullable SUM(x) means the validator established that nulls
  // cannot occur (the group is never empty and x is never null), so
  // SUM0(x) alone is the translation.
  if (!oldCall.getType().isNullable()) {
    return sumZeroRef;
  }

  final RexNode countRef =
      rexBuilder.addAggCall(countCall, groupCount, oldAggRel.indicator,
          newCalls, aggCallMapping, ImmutableList.of(argType));
  final RexNode countIsZero =
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
  return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      countIsZero, rexBuilder.constantNull(), sumZeroRef);
}
 
Example 11
Source File: AggregateMergeRule.java    From calcite with Apache License 2.0 4 votes vote down vote up
/**
 * Merges a top aggregate with the bottom aggregate it reads from into a
 * single aggregate over the bottom aggregate's input.
 *
 * <p>The method returns without transforming when the top aggregate has
 * more group keys than the bottom one, when its keys do not map onto a
 * subset of the bottom keys, when any involved call is unsupported or has
 * no arguments, when a top call does not refer to a bottom aggregate call,
 * when a bottom COUNT would be merged under an empty grouping set, or when
 * the splitter fails to merge a pair of calls.
 *
 * @param call rule call whose operand 0 is the top aggregate and operand 1
 *     the bottom aggregate
 */
public void onMatch(RelOptRuleCall call) {
  final Aggregate topAgg = call.rel(0);
  final Aggregate bottomAgg = call.rel(1);
  if (topAgg.getGroupCount() > bottomAgg.getGroupCount()) {
    return;
  }

  // Map each output ordinal of the bottom aggregate's group keys to the
  // input field it groups on.
  final ImmutableBitSet bottomGroupSet = bottomAgg.getGroupSet();
  final Map<Integer, Integer> ordinalToField = new HashMap<>();
  for (int field : bottomGroupSet) {
    ordinalToField.put(ordinalToField.size(), field);
  }
  for (int key : topAgg.getGroupSet()) {
    if (!ordinalToField.containsKey(key)) {
      return;
    }
  }

  // The top aggregate's keys must be a subset of the lower aggregate's keys.
  final ImmutableBitSet topGroupSet =
      topAgg.getGroupSet().permute(ordinalToField);
  if (!bottomGroupSet.contains(topGroupSet)) {
    return;
  }

  final boolean hasEmptyGroup =
      topAgg.getGroupSets().stream().anyMatch(ImmutableBitSet::isEmpty);

  final List<AggregateCall> mergedCalls = new ArrayList<>();
  for (AggregateCall topCall : topAgg.getAggCallList()) {
    if (!isAggregateSupported(topCall)
        || topCall.getArgList().isEmpty()) {
      return;
    }

    // The top call's argument has to point at one of the bottom
    // aggregate's calls.
    final int bottomIndex =
        topCall.getArgList().get(0) - bottomGroupSet.cardinality();
    if (bottomIndex < 0
        || bottomIndex >= bottomAgg.getAggCallList().size()) {
      return;
    }

    final AggregateCall bottomCall =
        bottomAgg.getAggCallList().get(bottomIndex);
    if (!isAggregateSupported(bottomCall)) {
      return;
    }
    // Do not merge when the top aggregate has an empty group key and the
    // lower function is COUNT: on empty input the lower aggregate yields
    // no rows, whereas the merged aggregate would yield one row holding 0,
    // which is wrong.
    if (hasEmptyGroup
        && bottomCall.getAggregation() == SqlStdOperatorTable.COUNT) {
      return;
    }

    final SqlSplittableAggFunction splitter = Objects.requireNonNull(
        bottomCall.getAggregation().unwrap(SqlSplittableAggFunction.class));
    final AggregateCall mergedCall = splitter.merge(topCall, bottomCall);
    if (mergedCall == null) {
      // The pair of calls could not be merged; bail out.
      return;
    }
    mergedCalls.add(mergedCall);
  }

  // Re-map the grouping sets unless the top aggregate is a simple group.
  final ImmutableList<ImmutableBitSet> newGroupingSets;
  if (topAgg.getGroupType() == Group.SIMPLE) {
    newGroupingSets = null;
  } else {
    newGroupingSets =
        ImmutableBitSet.ORDERING.immutableSortedCopy(
            ImmutableBitSet.permute(topAgg.getGroupSets(), ordinalToField));
  }

  call.transformTo(
      topAgg.copy(topAgg.getTraitSet(), bottomAgg.getInput(), topGroupSet,
          newGroupingSets, mergedCalls));
}