/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.datasketches.hive.theta; import static org.apache.datasketches.Util.DEFAULT_UPDATE_SEED; import java.util.Arrays; import org.apache.datasketches.hive.common.BytesWritableHelper; import org.apache.datasketches.memory.Memory; import org.apache.datasketches.theta.Intersection; import org.apache.datasketches.theta.SetOperation; import org.apache.datasketches.theta.Sketch; import org.apache.datasketches.theta.Sketches; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.LongWritable; @Description( name = "intersectSketch", value = "_FUNC_(sketch, seed) - Compute the intersection of sketches", extended = "Example:\n" + "> SELECT intersectSketch(sketch) FROM src;\n" + "The return value is a binary blob that contains a compact sketch, which can " + "be operated on by the other sketch-related functions. " + "The seed is optional, " + "and using it is not recommended unless you really know why you need it.") @SuppressWarnings("javadoc") public class IntersectSketchUDAF extends AbstractGenericUDAFResolver { @Override public GenericUDAFEvaluator getEvaluator(final GenericUDAFParameterInfo info) throws SemanticException { final ObjectInspector[] inspectors = info.getParameterObjectInspectors(); if (inspectors.length < 1) { throw new UDFArgumentException("Please specify at least 1 argument"); } if (inspectors.length > 2) { throw new UDFArgumentTypeException(inspectors.length - 1, "Please specify no more than 2 arguments"); } ObjectInspectorValidator.validateGivenPrimitiveCategory(inspectors[0], 0, PrimitiveCategory.BINARY); if (inspectors.length > 1) { ObjectInspectorValidator.validateIntegralParameter(inspectors[1], 1); } return new IntersectSketchUDAFEvaluator(); } static class IntersectSketchUDAFEvaluator extends GenericUDAFEvaluator { protected static final String SEED_FIELD = "seed"; protected static final String SKETCH_FIELD = "sketch"; // FOR PARTIAL1 and COMPLETE modes: ObjectInspectors for original data private transient PrimitiveObjectInspector inputObjectInspector; protected transient PrimitiveObjectInspector seedObjectInspector; // FOR PARTIAL2 and FINAL modes: ObjectInspectors for partial aggregations protected transient StructObjectInspector intermediateObjectInspector; @Override public ObjectInspector init(final Mode mode, final ObjectInspector[] parameters) throws HiveException { super.init(mode, parameters); if ((mode == Mode.PARTIAL1) || (mode == Mode.COMPLETE)) { inputObjectInspector = (PrimitiveObjectInspector) parameters[0]; if (parameters.length > 1) { seedObjectInspector = (PrimitiveObjectInspector) parameters[1]; } } else { intermediateObjectInspector = (StandardStructObjectInspector) parameters[0]; } if ((mode == Mode.PARTIAL1) || (mode == Mode.PARTIAL2)) { // intermediate results need to include the seed return ObjectInspectorFactory.getStandardStructObjectInspector( Arrays.asList(SEED_FIELD, SKETCH_FIELD), Arrays.asList( PrimitiveObjectInspectorFactory .getPrimitiveWritableObjectInspector(PrimitiveCategory.LONG), PrimitiveObjectInspectorFactory .getPrimitiveWritableObjectInspector(PrimitiveCategory.BINARY) ) ); } // final results include just the sketch return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.BINARY); } @Override public void iterate(final @SuppressWarnings("deprecation") AggregationBuffer buf, final Object[] data) throws HiveException { if (data[0] == null) { return; } final IntersectionState state = (IntersectionState) buf; if (!state.isInitialized()) { long seed = DEFAULT_UPDATE_SEED; if (seedObjectInspector != null) { seed = PrimitiveObjectInspectorUtils.getLong(data[1], seedObjectInspector); } state.init(seed); } final byte[] serializedSketch = (byte[]) inputObjectInspector.getPrimitiveJavaObject(data[0]); if (serializedSketch == null) { return; } state.update(Memory.wrap(serializedSketch)); } @Override public Object terminatePartial(final @SuppressWarnings("deprecation") AggregationBuffer buf) throws HiveException { final IntersectionState state = (IntersectionState) buf; final Sketch intermediate = state.getResult(); if (intermediate == null) { return null; } final byte[] bytes = intermediate.toByteArray(); return Arrays.asList( new LongWritable(state.getSeed()), new BytesWritable(bytes) ); } @Override public void merge(final @SuppressWarnings("deprecation") AggregationBuffer buf, final Object data) throws HiveException { if (data == null) { return; } final IntersectionState state = (IntersectionState) buf; if (!state.isInitialized()) { final long seed = ((LongWritable) intermediateObjectInspector.getStructFieldData( data, intermediateObjectInspector.getStructFieldRef(SEED_FIELD))).get(); state.init(seed); } final Memory serializedSketch = BytesWritableHelper.wrapAsMemory( (BytesWritable) intermediateObjectInspector.getStructFieldData( data, intermediateObjectInspector.getStructFieldRef(SKETCH_FIELD))); state.update(serializedSketch); } @Override public Object terminate(final @SuppressWarnings("deprecation") AggregationBuffer buf) throws HiveException { final IntersectionState state = (IntersectionState) buf; final Sketch resultSketch = state.getResult(); if (resultSketch == null) { return null; } return new BytesWritable(resultSketch.toByteArray()); } @SuppressWarnings("deprecation") @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { return new IntersectionState(); } @Override public void reset(final @SuppressWarnings("deprecation") AggregationBuffer buf) throws HiveException { final IntersectionState state = (IntersectionState) buf; state.reset(); } static class IntersectionState extends AbstractAggregationBuffer { private long seed_; private Intersection intersection_; boolean isInitialized() { return intersection_ != null; } void init(final long seed) { seed_ = seed; intersection_ = SetOperation.builder().setSeed(seed).buildIntersection(); } long getSeed() { return seed_; } void update(final Memory serializedSketch) { intersection_.update(Sketches.wrapSketch(serializedSketch, seed_)); } Sketch getResult() { if (intersection_ == null) { return null; } return intersection_.getResult(); } void reset() { intersection_ = null; } } } }