package com.rtbhouse.utils.avro; import com.google.common.base.Charsets; import com.google.common.collect.Lists; import com.google.common.hash.HashFunction; import com.google.common.hash.Hashing; import com.sun.codemodel.JBlock; import com.sun.codemodel.JCodeModel; import com.sun.codemodel.JDefinedClass; import org.apache.avro.Schema; import org.apache.avro.SchemaNormalization; import org.apache.avro.io.parsing.Symbol; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.tools.JavaCompiler; import javax.tools.ToolProvider; import java.io.File; import java.io.IOException; import java.io.OutputStream; import java.io.PrintStream; import java.lang.reflect.Field; import java.util.Collections; import java.util.ListIterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadLocalRandom; public abstract class FastDeserializerGeneratorBase<T> { private static final Logger LOGGER = LoggerFactory.getLogger(FastDeserializerGeneratorBase.class); public static final String GENERATED_PACKAGE_NAME = "com.rtbhouse.utils.avro.deserialization.generated"; public static final String GENERATED_SOURCES_PATH = "/com/rtbhouse/utils/avro/deserialization/generated/"; protected JCodeModel codeModel; protected JDefinedClass deserializerClass; protected final Schema writer; protected final Schema reader; private File destination; private ClassLoader classLoader; private String compileClassPath; FastDeserializerGeneratorBase(Schema writer, Schema reader, File destination, ClassLoader classLoader, String compileClassPath) { this.writer = writer; this.reader = reader; this.destination = destination; this.classLoader = classLoader; this.compileClassPath = compileClassPath; codeModel = new JCodeModel(); } public abstract FastDeserializer<T> generateDeserializer(); @SuppressWarnings("unchecked") protected Class<FastDeserializer<T>> compileClass(final String className) throws IOException, ClassNotFoundException { final OutputStream infoLoggingStream = LoggingOutputStream.infoLoggingStream(LOGGER); final OutputStream errorLoggingStream = LoggingOutputStream.errorLoggingStream(LOGGER); codeModel.build(destination, new PrintStream(infoLoggingStream)); JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); if (compiler == null) { throw new FastDeserializerGeneratorException("no system java compiler is available"); } int compileResult; if (compileClassPath != null) { compileResult = compiler.run( null, infoLoggingStream, errorLoggingStream, "-cp", compileClassPath, destination.getAbsolutePath() + GENERATED_SOURCES_PATH + className + ".java" ); } else { compileResult = compiler.run( null, infoLoggingStream, errorLoggingStream, destination.getAbsolutePath() + GENERATED_SOURCES_PATH + className + ".java"); } if (compileResult != 0) { throw new FastDeserializerGeneratorException("unable to compile: " + className); } return (Class<FastDeserializer<T>>) classLoader.loadClass(GENERATED_PACKAGE_NAME + "." + className); } protected ListIterator<Symbol> actionIterator(FieldAction action) { ListIterator<Symbol> actionIterator = null; if (action.getSymbolIterator() != null) { actionIterator = action.getSymbolIterator(); } else if (action.getSymbol().production != null) { actionIterator = Lists.newArrayList(reverseSymbolArray(action.getSymbol().production)) .listIterator(); } else { actionIterator = Collections.emptyListIterator(); } while (actionIterator.hasNext()) { Symbol symbol = actionIterator.next(); if (symbol instanceof Symbol.ErrorAction) { throw new FastDeserializerGeneratorException(((Symbol.ErrorAction) symbol).msg); } if (symbol instanceof Symbol.FieldOrderAction) { break; } } return actionIterator; } protected void forwardToExpectedDefault(ListIterator<Symbol> symbolIterator) { Symbol symbol; while (symbolIterator.hasNext()) { symbol = symbolIterator.next(); if (symbol instanceof Symbol.ErrorAction) { throw new FastDeserializerGeneratorException(((Symbol.ErrorAction) symbol).msg); } if (symbol instanceof Symbol.DefaultStartAction) { return; } } throw new FastDeserializerGeneratorException("DefaultStartAction symbol expected!"); } protected FieldAction seekFieldAction(boolean shouldReadCurrent, Schema.Field field, ListIterator<Symbol> symbolIterator) { Schema.Type type = field.schema().getType(); if (!shouldReadCurrent) { return FieldAction.fromValues(type, false, EMPTY_SYMBOL); } boolean shouldRead = true; Symbol fieldSymbol = END_SYMBOL; if (Schema.Type.RECORD.equals(type)) { if (symbolIterator.hasNext()) { fieldSymbol = symbolIterator.next(); if (fieldSymbol instanceof Symbol.SkipAction) { return FieldAction.fromValues(type, false, fieldSymbol); } else { symbolIterator.previous(); } } return FieldAction.fromValues(type, true, symbolIterator); } while (symbolIterator.hasNext()) { Symbol symbol = symbolIterator.next(); if (symbol instanceof Symbol.ErrorAction) { throw new FastDeserializerGeneratorException(((Symbol.ErrorAction) symbol).msg); } if (symbol instanceof Symbol.SkipAction) { shouldRead = false; fieldSymbol = symbol; break; } if (symbol instanceof Symbol.WriterUnionAction) { if (symbolIterator.hasNext()) { symbol = symbolIterator.next(); if (symbol instanceof Symbol.Alternative) { shouldRead = true; fieldSymbol = symbol; break; } } } if (symbol.kind == Symbol.Kind.TERMINAL) { shouldRead = true; if (symbolIterator.hasNext()) { symbol = symbolIterator.next(); if (symbol instanceof Symbol.Repeater) { fieldSymbol = symbol; } else { fieldSymbol = symbolIterator.previous(); } } else if (!symbolIterator.hasNext() && getSymbolPrintName(symbol) != null) { fieldSymbol = symbol; } break; } } return FieldAction.fromValues(type, shouldRead, fieldSymbol); } protected static final class FieldAction { private Schema.Type type; private boolean shouldRead; private Symbol symbol; private ListIterator<Symbol> symbolIterator; private FieldAction(Schema.Type type, boolean shouldRead, Symbol symbol) { this.type = type; this.shouldRead = shouldRead; this.symbol = symbol; } private FieldAction(Schema.Type type, boolean shouldRead, ListIterator<Symbol> symbolIterator) { this.type = type; this.shouldRead = shouldRead; this.symbolIterator = symbolIterator; } public Schema.Type getType() { return type; } public boolean getShouldRead() { return shouldRead; } public Symbol getSymbol() { return symbol; } public ListIterator<Symbol> getSymbolIterator() { return symbolIterator; } public static FieldAction fromValues(Schema.Type type, boolean read, Symbol symbol) { return new FieldAction(type, read, symbol); } public static FieldAction fromValues(Schema.Type type, boolean read, ListIterator<Symbol> symbolIterator) { return new FieldAction(type, read, symbolIterator); } } protected static final Symbol EMPTY_SYMBOL = new Symbol(Symbol.Kind.TERMINAL, new Symbol[] {}) { }; protected static final Symbol END_SYMBOL = new Symbol(Symbol.Kind.TERMINAL, new Symbol[] {}) { }; protected static Symbol[] reverseSymbolArray(Symbol[] symbols) { Symbol[] reversedSymbols = new Symbol[symbols.length]; for (int i = 0; i < symbols.length; i++) { reversedSymbols[symbols.length - i - 1] = symbols[i]; } return reversedSymbols; } public static String getClassName(Schema writerSchema, Schema readerSchema, String description) { Integer writerSchemaId = Math.abs(getSchemaId(writerSchema)); Integer readerSchemaId = Math.abs(getSchemaId(readerSchema)); if (Schema.Type.RECORD.equals(readerSchema.getType())) { return readerSchema.getName() + description + "Deserializer" + writerSchemaId + "_" + readerSchemaId; } else if (Schema.Type.ARRAY.equals(readerSchema.getType())) { return "Array" + description + "Deserializer" + writerSchemaId + "_" + readerSchemaId; } else if (Schema.Type.MAP.equals(readerSchema.getType())) { return "Map" + description + "Deserializer" + writerSchemaId + "_" + readerSchemaId; } throw new FastDeserializerGeneratorException("Unsupported return type: " + readerSchema.getType()); } protected static String getVariableName(String name) { return name + nextRandomInt(); } protected static String getSymbolPrintName(Symbol symbol) { String printName; try { Field field = symbol.getClass().getDeclaredField("printName"); field.setAccessible(true); printName = (String) field.get(symbol); field.setAccessible(false); } catch (ReflectiveOperationException e) { throw new FastDeserializerGeneratorException(e); } return printName; } private static final Map<Schema, Integer> SCHEMA_IDS_CACHE = new ConcurrentHashMap<>(); private static final HashFunction HASH_FUNCTION = Hashing.murmur3_128(); public static int getSchemaId(Schema schema) { Integer schemaId = SCHEMA_IDS_CACHE.get(schema); if (schemaId == null) { String schemaString = SchemaNormalization.toParsingForm(schema); schemaId = HASH_FUNCTION.hashString(schemaString, Charsets.UTF_8).asInt(); SCHEMA_IDS_CACHE.put(schema, schemaId); } return schemaId; } protected static int nextRandomInt() { return Math.abs(ThreadLocalRandom.current().nextInt()); } protected static void assignBlockToBody(Object codeContainer, JBlock body) { try { Field field = codeContainer.getClass().getDeclaredField("body"); field.setAccessible(true); field.set(codeContainer, body); field.setAccessible(false); } catch (ReflectiveOperationException e) { throw new FastDeserializerGeneratorException(e); } } }