/* * Copyright (C) 2016-2020 ActionTech. * License: http://www.gnu.org/licenses/gpl.html GPL version 2 or higher. */ package com.actiontech.dble.plan.common.item.function.sumfunc; import com.actiontech.dble.net.mysql.RowDataPacket; import com.actiontech.dble.plan.common.field.Field; import com.actiontech.dble.plan.common.item.Item; import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr; import com.alibaba.druid.sql.ast.expr.SQLAggregateOption; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.util.List; public class ItemSumAvg extends ItemSumSum { private long count; public ItemSumAvg(List<Item> args, boolean distinct, boolean isPushDown, List<Field> fields) { super(args, distinct, isPushDown, fields); count = 0; } @Override public void fixLengthAndDec() { super.fixLengthAndDec(); maybeNull = nullValue = true; } @Override public SumFuncType sumType() { return hasWithDistinct() ? SumFuncType.AVG_DISTINCT_FUNC : SumFuncType.AVG_FUNC; } @Override public void clear() { super.clear(); count = 0; } @Override public Object getTransAggObj() { AvgAggData aggData = new AvgAggData(sum, count, nullValue); return aggData; } @Override public int getTransSize() { return 20; } @Override public boolean add(RowDataPacket row, Object transObj) { if (transObj != null) { AvgAggData data = (AvgAggData) transObj; if (super.add(row, data)) return true; if (!data.isNull) count += data.count; } else { if (super.add(row, null)) return true; if (!aggr.argIsNull()) count++; } return false; } @Override public boolean pushDownAdd(RowDataPacket row) { // avg(n) will be push down as sum(n) and count(n); assert (getArgCount() == 2); count += args.get(1).valInt().longValue(); return super.add(row, null); } @Override public BigDecimal valReal() { if (aggr != null) aggr.endup(); if (count == 0) { nullValue = true; return BigDecimal.ZERO; } return super.valReal().divide(new BigDecimal(count), decimals + 4, RoundingMode.HALF_UP); } @Override public BigInteger valInt() { return valReal().toBigInteger(); } @Override public BigDecimal valDecimal() { if (aggr != null) aggr.endup(); if (count == 0) { nullValue = true; return null; } return valReal(); } @Override public String valStr() { if (aggr != null) aggr.endup(); if (hybridType == ItemResult.DECIMAL_RESULT) return valStringFromDecimal(); return valStringFromReal(); } @Override public void cleanup() { count = 0; super.cleanup(); } @Override public final String funcName() { return "AVG"; } @Override public void noRowsInResult() { } @Override public SQLExpr toExpression() { Item arg0 = args.get(0); SQLAggregateExpr aggregate = new SQLAggregateExpr(funcName()); aggregate.addArgument(arg0.toExpression()); if (hasWithDistinct()) { aggregate.setOption(SQLAggregateOption.DISTINCT); } return aggregate; } @Override protected Item cloneStruct(boolean forCalculate, List<Item> calArgs, boolean isPushDown, List<Field> fields) { if (!forCalculate) { List<Item> newArgs = cloneStructList(args); return new ItemSumAvg(newArgs, hasWithDistinct(), false, null); } else { return new ItemSumAvg(calArgs, hasWithDistinct(), isPushDown, fields); } } private static class AvgAggData extends AggData { private static final long serialVersionUID = -1831762635995954526L; private long count; AvgAggData(BigDecimal sum, long count, boolean isNull) { super(sum, isNull); this.count = count; } } }