/**
 * Copyright 2019 LinkedIn Corporation. All rights reserved.
 * Licensed under the BSD-2 Clause license.
 * See LICENSE in the project root for license information.
 */
package com.linkedin.transport.codegen;

import com.google.common.collect.ImmutableList;
import com.linkedin.transport.api.udf.StdUDF;
import com.linkedin.transport.api.udf.TopLevelStdUDF;
import com.linkedin.transport.compile.TransportUDFMetadata;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.WildcardTypeName;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import javax.lang.model.element.Modifier;


/**
 * {@link WrapperGenerator} that emits one Hive wrapper source file per top-level UDF class.
 *
 * <p>For a top-level class {@code a.b.Foo}, the generated class is {@code a.b.hive.Foo}. It extends
 * {@code com.linkedin.transport.hive.StdUdfWrapper} and overrides {@code getTopLevelUdfClass()} and
 * {@code getStdUdfImplementations()}.
 */
public class HiveWrapperGenerator implements WrapperGenerator {

  private static final String HIVE_PACKAGE_SUFFIX = "hive";
  private static final String GET_TOP_LEVEL_UDF_CLASS_METHOD = "getTopLevelUdfClass";
  private static final String GET_STD_UDF_IMPLEMENTATIONS_METHOD = "getStdUdfImplementations";
  private static final ClassName HIVE_STD_UDF_WRAPPER_CLASS_NAME =
      ClassName.bestGuess("com.linkedin.transport.hive.StdUdfWrapper");

  /**
   * Generates a wrapper for every top-level class listed in the context's UDF metadata and writes
   * each one under the context's sources output directory.
   *
   * @param context supplies the UDF metadata and the sources output directory
   * @throws RuntimeException if a generated source file cannot be written
   */
  @Override
  public void generateWrappers(WrapperGeneratorContext context) {
    TransportUDFMetadata udfMetadata = context.getTransportUdfMetadata();
    for (String topLevelClass : udfMetadata.getTopLevelClasses()) {
      generateWrapper(topLevelClass, udfMetadata.getStdUDFImplementations(topLevelClass),
          context.getSourcesOutputDir());
    }
  }

  /**
   * Builds the wrapper class for a single top-level UDF class and writes it to {@code outputDir}.
   *
   * @param topLevelClass name of the top-level UDF class, as reported by the UDF metadata
   * @param implementationClasses names of the classes instantiated by the generated
   *     {@code getStdUdfImplementations()}
   * @param outputDir root directory the generated source file is written under
   * @throws RuntimeException if writing the generated source file fails with an {@link IOException}
   */
  private void generateWrapper(String topLevelClass, Collection<String> implementationClasses, File outputDir) {
    ClassName topLevelClassName = ClassName.bestGuess(topLevelClass);
    // The wrapper keeps the top-level class's simple name but lives in the "<package>.hive" sub-package.
    ClassName wrapperClassName =
        ClassName.get(topLevelClassName.packageName() + "." + HIVE_PACKAGE_SUFFIX, topLevelClassName.simpleName());

    /*
      Generates ->

      public class ${wrapperClassName} extends StdUdfWrapper {

        .
        .
        .

      }
     */
    TypeSpec wrapperClass = TypeSpec.classBuilder(wrapperClassName)
        .addModifiers(Modifier.PUBLIC)
        .superclass(HIVE_STD_UDF_WRAPPER_CLASS_NAME)
        .addMethod(buildGetTopLevelUdfClassMethod(topLevelClassName))
        .addMethod(buildGetStdUdfImplementationsMethod(implementationClasses))
        .build();

    JavaFile javaFile = JavaFile.builder(wrapperClassName.packageName(), wrapperClass).build();

    try {
      javaFile.writeTo(outputDir);
    } catch (IOException e) {
      // Only I/O failures are reported as write errors; other exceptions propagate unchanged so that
      // programming errors are not mislabelled. The cause is preserved for diagnosis.
      throw new RuntimeException(
          "Error writing wrapper to file: " + wrapperClassName + " (output directory: " + outputDir + ")", e);
    }
  }

  /** Builds the {@code getTopLevelUdfClass()} override that returns the top-level UDF class literal. */
  private static MethodSpec buildGetTopLevelUdfClassMethod(ClassName topLevelClassName) {
    /*
      Generates ->

      @Override
      protected Class<? extends TopLevelStdUDF> getTopLevelUdfClass() {
        return ${topLevelClass}.class;
      }
     */
    return MethodSpec.methodBuilder(GET_TOP_LEVEL_UDF_CLASS_METHOD)
        .addAnnotation(Override.class)
        .returns(
            ParameterizedTypeName.get(ClassName.get(Class.class), WildcardTypeName.subtypeOf(TopLevelStdUDF.class)))
        .addModifiers(Modifier.PROTECTED)
        .addStatement("return $T.class", topLevelClassName)
        .build();
  }

  /** Builds the {@code getStdUdfImplementations()} override that instantiates every implementation class. */
  private static MethodSpec buildGetStdUdfImplementationsMethod(Collection<String> implementationClasses) {
    /*
      Generates ->

      @Override
      protected List<? extends StdUDF> getStdUdfImplementations() {
        return ImmutableList.of(
          new ${implementationClasses(0)}(),
          new ${implementationClasses(1)}(),
          .
          .
          .
        );
      }
     */
    // Each class name is emitted verbatim as a "new <name>()" expression, in the collection's iteration order.
    String constructorCalls = implementationClasses.stream()
        .map(clazz -> "new " + clazz + "()")
        .collect(Collectors.joining(", "));

    return MethodSpec.methodBuilder(GET_STD_UDF_IMPLEMENTATIONS_METHOD)
        .addAnnotation(Override.class)
        .returns(ParameterizedTypeName.get(ClassName.get(List.class), WildcardTypeName.subtypeOf(StdUDF.class)))
        .addModifiers(Modifier.PROTECTED)
        .addStatement("return $T.of($L)", ImmutableList.class, constructorCalls)
        .build();
  }
}