/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.tools.map; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.Preconditions; import java.util.HashMap; import java.util.Map; import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; //@formatter:off @Description(name = "merge_maps", value = "_FUNC_(Map x) - Returns a map which contains the union of an aggregation of maps." + " Note that an existing value of a key can be replaced with the other duplicate key entry.", extended = "SELECT \n" + " merge_maps(m) \n" + "FROM (\n" + " SELECT map('A',10,'B',20,'C',30) \n" + " UNION ALL \n" + " SELECT map('A',10,'B',20,'C',30)\n" + ") t") //@formatter:on public final class MergeMapsUDAF extends AbstractGenericUDAFResolver { @Override public MergeMapsEvaluator getEvaluator(TypeInfo[] types) throws SemanticException { if (types.length != 1) { throw new UDFArgumentTypeException(types.length - 1, "One argument is expected but got " + types.length); } TypeInfo paramType = types[0]; if (paramType.getCategory() != Category.MAP) { throw new UDFArgumentTypeException(0, "Only maps supported for now "); } return new MergeMapsEvaluator(); } public static final class MergeMapsEvaluator extends GenericUDAFEvaluator { private transient MapObjectInspector inputMapOI, mergeMapOI; private transient ObjectInspector inputKeyOI, inputValOI; @AggregationType(estimable = false) static final class MapAggBuffer extends AbstractAggregationBuffer { @Nonnull final Map<Object, Object> collectMap = new HashMap<Object, Object>(); } public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { Preconditions.checkArgument(parameters.length == 1); super.init(mode, parameters); // initialize input if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data this.inputMapOI = HiveUtils.asMapOI(parameters[0]); this.inputKeyOI = inputMapOI.getMapKeyObjectInspector(); this.inputValOI = inputMapOI.getMapValueObjectInspector(); } else {// from partial aggregation this.mergeMapOI = HiveUtils.asMapOI(parameters[0]); this.inputKeyOI = mergeMapOI.getMapKeyObjectInspector(); this.inputValOI = mergeMapOI.getMapValueObjectInspector(); } return ObjectInspectorFactory.getStandardMapObjectInspector( ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI), ObjectInspectorUtils.getStandardObjectInspector(inputValOI)); } @Override public MapAggBuffer getNewAggregationBuffer() throws HiveException { MapAggBuffer buff = new MapAggBuffer(); reset(buff); return buff; } @Override public void reset(@SuppressWarnings("deprecation") AggregationBuffer buff) throws HiveException { MapAggBuffer aggrBuf = (MapAggBuffer) buff; aggrBuf.collectMap.clear(); } @Override public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, Object[] parameters) throws HiveException { Preconditions.checkArgument(parameters.length == 1); Object param0 = parameters[0]; if (param0 == null) { return; } Map<?, ?> m = inputMapOI.getMap(param0); MapAggBuffer myagg = (MapAggBuffer) agg; putIntoSet(m, myagg.collectMap, inputMapOI); } @Override public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) throws HiveException { if (partial == null) { return; } MapAggBuffer myagg = (MapAggBuffer) agg; Map<?, ?> m = mergeMapOI.getMap(partial); putIntoSet(m, myagg.collectMap, mergeMapOI); } private static void putIntoSet(@Nonnull final Map<?, ?> m, @Nonnull final Map<Object, Object> dst, @Nonnull final MapObjectInspector mapOI) { final ObjectInspector keyOI = mapOI.getMapKeyObjectInspector(); final ObjectInspector valueOI = mapOI.getMapValueObjectInspector(); for (Map.Entry<?, ?> e : m.entrySet()) { Object k = e.getKey(); Object v = e.getValue(); Object keyCopy = ObjectInspectorUtils.copyToStandardObject(k, keyOI); Object valCopy = ObjectInspectorUtils.copyToStandardObject(v, valueOI); dst.put(keyCopy, valCopy); } } @Override @Nonnull public Map<Object, Object> terminatePartial( @SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { MapAggBuffer myagg = (MapAggBuffer) agg; return myagg.collectMap; } @Override public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { MapAggBuffer myagg = (MapAggBuffer) agg; return myagg.collectMap; } } }