/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.tools.map;

import hivemall.utils.collections.maps.BoundedSortedMap;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;

/**
 * Convert two aggregated columns into a sorted key-value map.
 */
@Description(name = "to_ordered_map",
        value = "_FUNC_(key, value [, const int k|const boolean reverseOrder=false]) "
                + "- Convert two aggregated columns into an ordered key-value map",
        extended = "with t as (\n" + "    select 10 as key, 'apple' as value\n" + "    union all\n"
                + "    select 3 as key, 'banana' as value\n" + "    union all\n"
                + "    select 4 as key, 'candy' as value\n" + ")\n" + "select\n"
                + "    to_ordered_map(key, value, true),   -- {10:\"apple\",4:\"candy\",3:\"banana\"} (reverse)\n"
                + "    to_ordered_map(key, value, 1),      -- {10:\"apple\"} (top-1)\n"
                + "    to_ordered_map(key, value, 2),      -- {10:\"apple\",4:\"candy\"} (top-2)\n"
                + "    to_ordered_map(key, value, 3),      -- {10:\"apple\",4:\"candy\",3:\"banana\"} (top-3)\n"
                + "    to_ordered_map(key, value, 100),    -- {10:\"apple\",4:\"candy\",3:\"banana\"} (top-100)\n"
                + "    to_ordered_map(key, value),         -- {3:\"banana\",4:\"candy\",10:\"apple\"} (natural)\n"
                + "    to_ordered_map(key, value, -1),     -- {3:\"banana\"} (tail-1)\n"
                + "    to_ordered_map(key, value, -2),     -- {3:\"banana\",4:\"candy\"} (tail-2)\n"
                + "    to_ordered_map(key, value, -3),     -- {3:\"banana\",4:\"candy\",10:\"apple\"} (tail-3)\n"
                + "    to_ordered_map(key, value, -100)    -- {3:\"banana\",4:\"candy\",10:\"apple\"} (tail-100)\n"
                + "from t")
public final class UDAFToOrderedMap extends UDAFToMap {

    /**
     * Validates the argument types and chooses the evaluator that matches the optional third
     * argument.
     *
     * <ul>
     * <li>two arguments, or a constant boolean {@code false}: {@link NaturalOrderedMapEvaluator}</li>
     * <li>constant boolean {@code true}: {@link ReverseOrderedMapEvaluator}</li>
     * <li>constant positive integer: {@link TopKOrderedMapEvaluator}</li>
     * <li>constant negative integer: {@link TailKOrderedMapEvaluator}</li>
     * </ul>
     *
     * @param info parameter information describing the call site
     * @return a new evaluator instance for the requested ordering
     * @throws UDFArgumentTypeException if the number of arguments is neither two nor three, the
     *         key is not a primitive type, or the third argument is neither a constant boolean
     *         nor a constant integer
     * @throws UDFArgumentException if the third argument is the integer zero
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
            throws SemanticException {
        @SuppressWarnings("deprecation")
        final TypeInfo[] argTypes = info.getParameters();
        final int numArgs = argTypes.length;
        if (numArgs < 2 || numArgs > 3) {
            throw new UDFArgumentTypeException(numArgs - 1,
                "Expecting two or three arguments: " + numArgs);
        }

        final TypeInfo keyType = argTypes[0];
        if (keyType.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                "Only primitive type arguments are accepted for the key but "
                        + keyType.getTypeName() + " was passed as parameter 1.");
        }

        // Without a third argument the map is unbounded and in ascending key order.
        if (numArgs == 2) {
            return new NaturalOrderedMapEvaluator();
        }

        final ObjectInspector optionOI = info.getParameterObjectInspectors()[2];

        // Boolean flag: unbounded map, flag decides the direction.
        if (HiveUtils.isConstBoolean(optionOI)) {
            final boolean descending = HiveUtils.getConstBoolean(optionOI);
            if (descending) {
                return new ReverseOrderedMapEvaluator();
            }
            return new NaturalOrderedMapEvaluator();
        }

        if (!HiveUtils.isConstInteger(optionOI)) {
            throw new UDFArgumentTypeException(2,
                "The third argument must be boolean or int type: " + argTypes[2].getTypeName());
        }

        // Integer k: bounded map, the sign of k decides the direction.
        final int k = HiveUtils.getConstInt(optionOI);
        if (k == 0) {
            throw new UDFArgumentException("Map size must be non-zero value: " + k);
        }
        if (k > 0) { // positive size => top-k
            return new TopKOrderedMapEvaluator();
        }
        return new TailKOrderedMapEvaluator();
    }

    /**
     * Evaluator whose aggregation buffer is backed by a {@link TreeMap} that orders entries by
     * the natural ordering of the keys.
     */
    public static class NaturalOrderedMapEvaluator extends UDAFToMapEvaluator {

        /**
         * Replaces the buffer's container with a fresh, empty {@link TreeMap}.
         *
         * @param agg the aggregation buffer to reset
         */
        @Override
        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            final MapAggregationBuffer buffer = (MapAggregationBuffer) agg;
            buffer.container = new TreeMap<Object, Object>();
        }

    }

    /**
     * Evaluator whose aggregation buffer is backed by a {@link TreeMap} that orders entries by
     * the reverse of the keys' natural ordering.
     */
    public static class ReverseOrderedMapEvaluator extends UDAFToMapEvaluator {

        /**
         * Replaces the buffer's container with a fresh, empty {@link TreeMap} built on
         * {@link Collections#reverseOrder()}.
         *
         * @param agg the aggregation buffer to reset
         */
        @Override
        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            final MapAggregationBuffer buffer = (MapAggregationBuffer) agg;
            buffer.container = new TreeMap<Object, Object>(Collections.<Object>reverseOrder());
        }

    }

    public static class TopKOrderedMapEvaluator extends GenericUDAFEvaluator {

        // Inspector for keys: the first argument when reading original rows, otherwise the key
        // inspector of the "partialMap" struct field (both assigned in init).
        protected PrimitiveObjectInspector inputKeyOI;
        // Inspector for values: the second argument when reading original rows, otherwise the
        // value inspector of the "partialMap" struct field.
        protected ObjectInspector inputValueOI;
        // Standard map inspector used by merge() to read a partial map; only assigned when init
        // runs in a mode that consumes partial aggregations.
        protected MapObjectInspector partialMapOI;
        // Inspector for the size: the third argument when reading original rows, otherwise the
        // inspector of the "size" struct field.
        protected PrimitiveObjectInspector sizeOI;

        // Struct inspector of the incoming partial aggregation; only assigned when init runs in
        // a mode that consumes partial aggregations.
        protected StructObjectInspector internalMergeOI;

        // References to the "partialMap" and "size" fields of the partial aggregation struct.
        protected StructField partialMapField;
        protected StructField sizeField;

        /**
         * Captures the inspectors needed for the given mode and reports the inspector of the
         * value this phase produces.
         *
         * <p>
         * In {@code PARTIAL1} and {@code COMPLETE} the arguments are the original key, value and
         * size columns. In the other modes the single argument is the partial-aggregation struct
         * with the fields {@code "partialMap"} and {@code "size"}.
         *
         * @param mode the aggregation phase
         * @param argOIs inspectors of the arguments for that phase
         * @return the partial-aggregation struct inspector for {@code PARTIAL1} and
         *         {@code PARTIAL2}, otherwise a standard map inspector for the final result
         */
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException {
            super.init(mode, argOIs);

            // ---- input side ----
            final boolean readsOriginalRows = (mode == Mode.PARTIAL1) || (mode == Mode.COMPLETE);
            if (readsOriginalRows) {
                this.inputKeyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]);
                this.inputValueOI = argOIs[1];
                this.sizeOI = HiveUtils.asIntegerOI(argOIs[2]);
            } else {
                final StructObjectInspector partialOI = (StructObjectInspector) argOIs[0];
                this.internalMergeOI = partialOI;

                // Recover the key/value inspectors from the map field of the partial struct.
                this.partialMapField = partialOI.getStructFieldRef("partialMap");
                final MapObjectInspector incomingMapOI =
                        (MapObjectInspector) partialMapField.getFieldObjectInspector();
                this.inputKeyOI = HiveUtils.asPrimitiveObjectInspector(
                    incomingMapOI.getMapKeyObjectInspector());
                this.inputValueOI = incomingMapOI.getMapValueObjectInspector();

                // merge() reads the partial map through a standard map inspector.
                this.partialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
                    ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI),
                    ObjectInspectorUtils.getStandardObjectInspector(inputValueOI));

                this.sizeField = partialOI.getStructFieldRef("size");
                this.sizeOI = (PrimitiveObjectInspector) sizeField.getFieldObjectInspector();
            }

            // ---- output side ----
            final boolean emitsPartial = (mode == Mode.PARTIAL1) || (mode == Mode.PARTIAL2);
            if (emitsPartial) {
                // consumed by terminatePartial
                return internalMergeOI(inputKeyOI, inputValueOI);
            }
            // consumed by terminate
            return ObjectInspectorFactory.getStandardMapObjectInspector(
                ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI),
                ObjectInspectorUtils.getStandardObjectInspector(inputValueOI));
        }

        /**
         * Builds the struct inspector that describes a partial aggregation: a standard map in
         * the field {@code "partialMap"} followed by a writable int in the field {@code "size"}.
         *
         * @param keyOI inspector of the map keys
         * @param valueOI inspector of the map values
         * @return a standard struct inspector with the two fields above, in that order
         */
        @Nonnull
        private static StructObjectInspector internalMergeOI(
                @Nonnull PrimitiveObjectInspector keyOI, @Nonnull ObjectInspector valueOI) {
            final ObjectInspector mapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
                ObjectInspectorUtils.getStandardObjectInspector(keyOI),
                ObjectInspectorUtils.getStandardObjectInspector(valueOI));

            final List<String> names = new ArrayList<String>(2);
            names.add("partialMap");
            names.add("size");

            final List<ObjectInspector> inspectors = new ArrayList<ObjectInspector>(2);
            inspectors.add(mapOI);
            inspectors.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

            return ObjectInspectorFactory.getStandardStructObjectInspector(names, inspectors);
        }

        /**
         * Aggregation state: the entries collected so far and the bound the container was
         * created with.
         */
        static class MapAggregationBuffer extends AbstractAggregationBuffer {
            /** Collected entries, or {@code null} when no container has been assigned. */
            @Nullable
            Map<Object, Object> container;

            /** Bound recorded alongside the container; defaults to {@code 0}. */
            int size;

            MapAggregationBuffer() {}
        }

        /**
         * Returns the buffer to its empty state: no container and a size of zero.
         *
         * @param agg the aggregation buffer to clear
         */
        @Override
        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            final MapAggregationBuffer buffer = (MapAggregationBuffer) agg;
            buffer.size = 0;
            buffer.container = null;
        }

        /**
         * Allocates an aggregation buffer and passes it through {@code reset} before handing
         * it out.
         *
         * @return a freshly reset buffer
         */
        @Override
        public MapAggregationBuffer getNewAggregationBuffer() throws HiveException {
            final MapAggregationBuffer fresh = new MapAggregationBuffer();
            reset(fresh);
            return fresh;
        }

        /**
         * Adds one input row to the aggregation buffer.
         *
         * <p>
         * Rows with a {@code null} key are ignored. The container is created on the first
         * accepted row, using the absolute value of the size argument as its bound.
         *
         * @param agg the aggregation buffer to update
         * @param parameters the row as {@code [key, value, size]}
         */
        @Override
        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
                Object[] parameters) throws HiveException {
            assert (parameters.length == 3);
            if (parameters[0] == null) {
                return;
            }

            MapAggregationBuffer myagg = (MapAggregationBuffer) agg;
            if (myagg.container == null) {
                // The size is only needed to create the container, so it is decoded once here
                // instead of on every row. It could be negative for tail-k, hence the abs().
                int size = Math.abs(HiveUtils.getInt(parameters[2], sizeOI));
                initBuffer(myagg, size);
            }

            Object key = ObjectInspectorUtils.copyToStandardObject(parameters[0], inputKeyOI);
            Object value = ObjectInspectorUtils.copyToStandardObject(parameters[1], inputValueOI);
            myagg.container.put(key, value);
        }

        /**
         * Creates the bounded container for the buffer and records its bound.
         *
         * <p>
         * NOTE(review): the {@code true} flag handed to {@link BoundedSortedMap} presumably
         * selects descending key order (top-k); confirm against that class.
         *
         * @param agg the buffer to initialize
         * @param size the maximum number of entries; must be greater than zero
         */
        void initBuffer(@Nonnull MapAggregationBuffer agg, @Nonnegative int size) {
            final boolean positive = size > 0;
            Preconditions.checkArgument(positive, "size MUST be greater than zero: " + size);

            final Map<Object, Object> bounded = new BoundedSortedMap<Object, Object>(size, true);
            agg.size = size;
            agg.container = bounded;
        }

        /**
         * Packages the buffer as a partial aggregation result.
         *
         * @param agg the aggregation buffer to export
         * @return a two-element array holding the container (possibly {@code null}) followed by
         *         the size wrapped in an {@link IntWritable}
         */
        @Override
        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            final MapAggregationBuffer buffer = (MapAggregationBuffer) agg;
            return new Object[] {buffer.container, new IntWritable(buffer.size)};
        }

        /**
         * Folds a partial aggregation result into the buffer.
         *
         * <p>
         * Nothing happens when the partial result or the map inside it is {@code null}. When the
         * buffer has no container yet, one is created with the size carried by the partial
         * result. Every entry is copied to a standard object before being stored.
         *
         * @param agg the aggregation buffer to update
         * @param partial a partial aggregation struct, or {@code null}
         */
        @Override
        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }

            final MapAggregationBuffer buffer = (MapAggregationBuffer) agg;

            final Object rawMap = internalMergeOI.getStructFieldData(partial, partialMapField);
            final Map<?, ?> incoming =
                    partialMapOI.getMap(HiveUtils.castLazyBinaryObject(rawMap));
            if (incoming == null) {
                return;
            }

            if (buffer.container == null) {
                // First non-empty partial: take the bound from the partial itself.
                final int bound = HiveUtils.getInt(
                    internalMergeOI.getStructFieldData(partial, sizeField), sizeOI);
                initBuffer(buffer, bound);
            }

            final Map<Object, Object> target = buffer.container;
            for (Map.Entry<?, ?> entry : incoming.entrySet()) {
                final Object keyCopy =
                        ObjectInspectorUtils.copyToStandardObject(entry.getKey(), inputKeyOI);
                final Object valueCopy =
                        ObjectInspectorUtils.copyToStandardObject(entry.getValue(), inputValueOI);
                target.put(keyCopy, valueCopy);
            }
        }

        /**
         * Produces the final result.
         *
         * @param agg the aggregation buffer to read
         * @return the buffer's container, or {@code null} when none was ever created
         */
        @Override
        @Nullable
        public Map<Object, Object> terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            return ((MapAggregationBuffer) agg).container;
        }

    }

    /**
     * Bounded evaluator selected for a negative size argument (tail-k). It differs from
     * {@link TopKOrderedMapEvaluator} only in how the container is created: the
     * {@link BoundedSortedMap} is built from the size alone.
     */
    public static class TailKOrderedMapEvaluator extends TopKOrderedMapEvaluator {

        /**
         * Creates the bounded container for the buffer and records its bound.
         *
         * @param agg the buffer to initialize
         * @param size the maximum number of entries; must be greater than zero
         */
        @Override
        void initBuffer(@Nonnull MapAggregationBuffer agg, @Nonnegative int size) {
            // Same guard as the parent class. The caller derives the size with Math.abs(), and
            // Math.abs(Integer.MIN_VALUE) is still negative, so a non-positive value can get here.
            Preconditions.checkArgument(size > 0, "size MUST be greater than zero: " + size);

            agg.container = new BoundedSortedMap<Object, Object>(size);
            agg.size = size;
        }
    }

}