package com.github.aaronshan.functions.array; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.Text; /** * @author ruifeng.shan * date: 2016-07-26 * time: 17:31 */ @Description(name = "array_join" , value = "_FUNC_(array<E>, delimiter, null_replacement) - concatenates the elements of the given array using the delimiter and an optional null_replacement to replace nulls." , extended = "Example:\n > select _FUNC_(array, delimiter) from src;\n> select _FUNC_(array, delimiter, null_replacement) from src;") public class UDFArrayJoin extends GenericUDF { private static final int ARRAY_IDX = 0; private static final int DELIMITER_IDX = 1; private static final int NULL_REPLACE_IDX = 2; private static final int MIN_ARG_COUNT = 2; // min Number of arguments to this UDF private static final int MAX_ARG_COUNT = 3; // max Number of arguments to this UDF private transient ObjectInspector delimiterOI; private transient ObjectInspector nullReplaceOI; private transient ListObjectInspector arrayOI; private transient ObjectInspector arrayElementOI; private transient Converter delimiterConvert; private transient Converter nullReplaceConvert; private Text result; public UDFArrayJoin() { } @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { // Check if two arguments were passed if (arguments.length > MAX_ARG_COUNT || arguments.length < MIN_ARG_COUNT) { throw new UDFArgumentLengthException( "The function array_join(array, delimiter) or array_join(array, delimiter, null_replacement) takes exactly " + MIN_ARG_COUNT + " or " + MAX_ARG_COUNT + " arguments."); } // Check if ARRAY_IDX argument is of category LIST if (!arguments[ARRAY_IDX].getCategory().equals(ObjectInspector.Category.LIST)) { throw new UDFArgumentTypeException(ARRAY_IDX, "\"" + org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME + "\" " + "expected at function array_join, but " + "\"" + arguments[ARRAY_IDX].getTypeName() + "\" " + "is found"); } arrayOI = (ListObjectInspector) arguments[ARRAY_IDX]; arrayElementOI = arrayOI.getListElementObjectInspector(); delimiterOI = arguments[DELIMITER_IDX]; delimiterConvert = ObjectInspectorConverters.getConverter(delimiterOI, PrimitiveObjectInspectorFactory.writableStringObjectInspector); if (arguments.length == MAX_ARG_COUNT) { nullReplaceOI = arguments[NULL_REPLACE_IDX]; nullReplaceConvert = ObjectInspectorConverters.getConverter(nullReplaceOI, PrimitiveObjectInspectorFactory.writableStringObjectInspector); } result = new Text(); return PrimitiveObjectInspectorFactory.writableStringObjectInspector; } @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { result.set(""); Object array = arguments[ARRAY_IDX].get(); Object delimiter = arguments[DELIMITER_IDX].get(); Object nullReplace = null; if (arguments.length == MAX_ARG_COUNT) { nullReplace = arguments[NULL_REPLACE_IDX].get(); } int arrayLength = arrayOI.getListLength(array); // Check if array is null or empty or value is null if (array == null || arrayLength <= 0) { return result; } StringBuffer stringBuffer = new StringBuffer(); Object listElement = arrayOI.getListElement(array, 0); appendElement(stringBuffer, listElement, nullReplace); for (int i = 1; i < arrayLength; ++i) { stringBuffer.append(delimiterConvert.convert(delimiter)); listElement = arrayOI.getListElement(array, i); appendElement(stringBuffer, listElement, nullReplace); } result.set(stringBuffer.toString()); return result; } private void appendElement(StringBuffer stringBuffer, Object listElement, Object nullReplace) { if (listElement == null) { if (nullReplace != null) { stringBuffer.append(nullReplaceConvert.convert(nullReplace)); } else { stringBuffer.append(listElement); } } else { stringBuffer.append(listElement); } } @Override public String getDisplayString(String[] strings) { assert (strings.length <= MAX_ARG_COUNT || strings.length >= MIN_ARG_COUNT); if (strings.length == 2) { return "array_join(" + strings[ARRAY_IDX] + ", " + strings[DELIMITER_IDX] + ")"; } else { return "array_join(" + strings[ARRAY_IDX] + ", " + strings[DELIMITER_IDX] + ", " + strings[NULL_REPLACE_IDX] + ")"; } } }