package com.github.aaronshan.functions.regexp;

import io.airlift.slice.Slices;
import java.util.ArrayList;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * @author ruifeng.shan
 * date: 2018-07-27
 * time: 22:38
 */
@Description(name = "regexp_extract_all"
        , value = "_FUNC_(string, string) - string(s) extracted using the given pattern\n" +
        "_FUNC_(string, string, long) - group(s) extracted using the given pattern."
        , extended = "Example:\n > select _FUNC_(string, pattern) from src;")
public class UDFRe2JRegexpExtractAll extends GenericUDF {
    private transient ArrayList<Object> result = new ArrayList<Object>();
    private transient Re2JRegexp re2JRegexp;

    public UDFRe2JRegexpExtractAll() {

    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // Check if two arguments were passed
        if (arguments.length != 2 && arguments.length != 3) {
            throw new UDFArgumentLengthException(
                    "The function regexp_extract_all takes exactly 2 or 3 arguments.");
        }

        for (int i = 0; i < 2; i++) {
            if (!ObjectInspectorUtils.compareTypes(PrimitiveObjectInspectorFactory.javaStringObjectInspector, arguments[i])) {
                throw new UDFArgumentTypeException(i,
                        "\"" + PrimitiveObjectInspectorFactory.javaStringObjectInspector.getTypeName() + "\" "
                                + "expected at function regexp_extract_all, but "
                                + "\"" + arguments[i].getTypeName() + "\" "
                                + "is found");
            }
        }

        if (arguments.length == 3) {
            if (!ObjectInspectorUtils.compareTypes(PrimitiveObjectInspectorFactory.javaLongObjectInspector, arguments[2])) {
                throw new UDFArgumentTypeException(2,
                        "\"" + PrimitiveObjectInspectorFactory.javaLongObjectInspector.getTypeName() + "\" "
                                + "expected at function regexp_extract_all, but "
                                + "\"" + arguments[2].getTypeName() + "\" "
                                + "is found");
            }
        }

        ObjectInspector expectOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;

        return ObjectInspectorFactory.getStandardListObjectInspector(expectOI);
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        String source = (String) arguments[0].get();
        String pattern = (String) arguments[1].get();
        Long groupIndex = 0L;
        if (arguments.length == 3) {
            groupIndex = (Long) arguments[2].get();
        }

        if (source == null) {
            return null;
        }

        if (re2JRegexp == null) {
            re2JRegexp = new Re2JRegexp(Integer.MAX_VALUE, 5, Slices.utf8Slice(pattern));
        }

        result.clear();
        result.addAll(re2JRegexp.extractAll(Slices.utf8Slice(source), groupIndex));

        return result;
    }

    @Override
    public String getDisplayString(String[] strings) {
        assert (strings.length == 2 || strings.length == 3);
        if (strings.length == 2) {
            return "regexp_extract_all(" + strings[0] + ", "
                    + strings[1] + ")";
        } else {
            return "regexp_extract_all(" + strings[0] + ", "
                    + strings[1] + ", " + strings[2] + ")";
        }
    }
}