/*
 * Copyright 2016 Yahoo Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.yahoo.hive.udf.funnel;

import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

@UDFType(deterministic = true)
@Description(name = "merge_funnel",
             value = "_FUNC_(funnel_column) - Merges funnels. Use with funnel UDF.",
             extended = "Example: SELECT merge_funnel(funnel)\n" +
                        "         FROM (SELECT funnel(action, timestamp, array('signup_page', 'email_signup'), \n" +
                        "                                                array('confirm_button'),\n" +
                        "                                                array('submit_button')) AS funnel\n" +
                        "               FROM table\n" +
                        "               GROUP BY user_id) t;")
public class Merge extends AbstractGenericUDAFResolver {
    static final Log LOG = LogFactory.getLog(Merge.class.getName());

    @Override
    public MergeEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
        // Get the parameters
        TypeInfo [] parameters = info.getParameters();

        // Check number of arguments
        if (parameters.length != 1) {
            throw new UDFArgumentLengthException("Please specify the funnel column.");
        }

        // Check if the parameter is not a list
        if (parameters[0].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0, "Only list type arguments are accepted but " + parameters[0].getTypeName() + " was passed as the first parameter.");
        }

        // Check that the list is an array of primitives
        if (((ListTypeInfo) parameters[0]).getListElementTypeInfo().getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "A long array argument should be passed, but " + parameters[0].getTypeName() + " was passed instead.");
        }

        // Check that the list is of type long
        // May want to add support for int/double/float later
        switch (((PrimitiveTypeInfo) ((ListTypeInfo) parameters[0]).getListElementTypeInfo()).getPrimitiveCategory()) {
            case LONG:
                break;
            default:
                throw new UDFArgumentTypeException(0, "A long array argument should be passed, but " + parameters[0].getTypeName() + " was passed instead.");
        }

        return new MergeEvaluator();
    }

    public static class MergeEvaluator extends GenericUDAFEvaluator {
        /** Input list object inspector. Used during iterate. */
        private ListObjectInspector listObjectInspector;

        /** Long object inspector. Used during merge. */
        private LongObjectInspector longObjectInspector;

        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            super.init(mode, parameters);

            // Setup the list and element object inspectors.
            listObjectInspector = (ListObjectInspector) parameters[0];
            longObjectInspector = (LongObjectInspector) listObjectInspector.getListElementObjectInspector();

            // Will return a list of longs
            return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaLongObjectInspector);
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new MergeAggregateBuffer();
        }

        @Override
        public void iterate(AggregationBuffer aggregate, Object[] parameters) throws HiveException {
            Object parameter = parameters[0];
            // If not null
            if (parameter != null) {
                // Get the funnel aggregate and the funnel data
                MergeAggregateBuffer funnelAggregate = (MergeAggregateBuffer) aggregate;
                List<Long> funnel = (List<Long>) listObjectInspector.getList(parameter);

                // Add the funnel to the funnel aggregate
                funnelAggregate.addFunnel(funnel);
            }
        }

        @Override
        public void merge(AggregationBuffer aggregate, Object partial) throws HiveException {
            if (partial != null) {

                // Get the funnel aggregate and the funnel data
                MergeAggregateBuffer funnelAggregate = (MergeAggregateBuffer) aggregate;

                // Convert the partial results into a list of longs
                List<Long> funnel = listObjectInspector.getList(partial)
                                                       .stream()
                                                       .map(longObjectInspector::get)
                                                       .collect(Collectors.toList());

                // Add the funnel to the funnel aggregate
                funnelAggregate.addFunnel(funnel);
            }
        }

        @Override
        public void reset(AggregationBuffer aggregate) throws HiveException {
            MergeAggregateBuffer funnelAggregate = (MergeAggregateBuffer) aggregate;
            funnelAggregate.clear();
        }

        @Override
        public Object terminate(AggregationBuffer aggregate) throws HiveException {
            MergeAggregateBuffer funnelAggregate = (MergeAggregateBuffer) aggregate;
            return funnelAggregate.output();
        }

        @Override
        public Object terminatePartial(AggregationBuffer aggregate) throws HiveException {
            MergeAggregateBuffer funnelAggregate = (MergeAggregateBuffer) aggregate;
            return funnelAggregate.output();
        }
    }
}