package com.packt.sfjd.ch8;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Spark SQL user-defined aggregate function that computes the arithmetic mean of a single
 * nullable {@code double} input column.
 *
 * <p>The aggregation buffer holds a running sum and a count of the non-null values seen.
 * Null inputs are skipped. A group with no non-null values evaluates to {@code null} rather
 * than {@code NaN}.
 */
public class AverageUDAF extends UserDefinedAggregateFunction  {
	private static final long serialVersionUID = 1L;

	/** Buffer slot holding the running sum of the non-null input values. */
	private static final int SUM_INDEX = 0;
	/** Buffer slot holding the number of non-null input values aggregated so far. */
	private static final int COUNT_INDEX = 1;

	/**
	 * Declares the input: one nullable {@code double} column named {@code counter}.
	 *
	 * @return the input schema
	 */
	@Override
	public StructType inputSchema() {
		return new StructType(new StructField[] {
				new StructField("counter", DataTypes.DoubleType, true, Metadata.empty()) });
	}

	/**
	 * Declares the result type of the aggregate.
	 *
	 * @return {@link DataTypes#DoubleType}
	 */
	@Override
	public DataType dataType() {
		return DataTypes.DoubleType;
	}

	/**
	 * Reports that the same input values always produce the same mean.
	 *
	 * <p>The mean depends only on the input values, so this returns {@code true}. That lets
	 * the optimizer treat the aggregate like the built-in {@code avg}.
	 *
	 * @return {@code true}
	 */
	@Override
	public boolean deterministic() {
		return true;
	}

	/**
	 * Declares the intermediate buffer: a {@code double} sum and a {@code double} count.
	 *
	 * @return the buffer schema
	 */
	@Override
	public StructType bufferSchema() {
		return new StructType()
				.add("sumVal", DataTypes.DoubleType)
				.add("countVal", DataTypes.DoubleType);
	}

	/**
	 * Resets both the sum and the count to zero.
	 *
	 * @param bufferAgg the buffer to initialize
	 */
	@Override
	public void initialize(MutableAggregationBuffer bufferAgg) {
		bufferAgg.update(SUM_INDEX, 0.0);
		bufferAgg.update(COUNT_INDEX, 0.0);
	}

	/**
	 * Folds one input row into the buffer.
	 *
	 * @param bufferAgg the buffer to update
	 * @param row an input row matching {@link #inputSchema()}
	 */
	@Override
	public void update(MutableAggregationBuffer bufferAgg, Row row) {
		// The input column is declared nullable. Skip nulls, as SQL avg does, rather than
		// failing inside getDouble.
		if (row.isNullAt(0)) {
			return;
		}
		bufferAgg.update(SUM_INDEX, bufferAgg.getDouble(SUM_INDEX) + row.getDouble(0));
		// Each non-null value contributes exactly one to the count.
		bufferAgg.update(COUNT_INDEX, bufferAgg.getDouble(COUNT_INDEX) + 1.0);
	}

	/**
	 * Combines a partial buffer into {@code bufferAgg} by adding sums and counts.
	 *
	 * @param bufferAgg the buffer to merge into
	 * @param row a partial buffer matching {@link #bufferSchema()}
	 */
	@Override
	public void merge(MutableAggregationBuffer bufferAgg, Row row) {
		bufferAgg.update(SUM_INDEX, bufferAgg.getDouble(SUM_INDEX) + row.getDouble(SUM_INDEX));
		bufferAgg.update(COUNT_INDEX, bufferAgg.getDouble(COUNT_INDEX) + row.getDouble(COUNT_INDEX));
	}

	/**
	 * Computes the mean from the final buffer.
	 *
	 * @param row the final buffer matching {@link #bufferSchema()}
	 * @return {@code sum / count} as a {@link Double}, or {@code null} if no non-null values
	 *     were aggregated
	 */
	@Override
	public Object evaluate(Row row) {
		double count = row.getDouble(COUNT_INDEX);
		// An empty (or all-null) group would otherwise produce 0.0 / 0.0 = NaN.
		if (count == 0.0) {
			return null;
		}
		return row.getDouble(SUM_INDEX) / count;
	}
}