/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.tinkerpop.gremlin.process.traversal.dsl;

import com.squareup.javapoet.ArrayTypeName;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import org.apache.tinkerpop.gremlin.process.remote.RemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.Traversal;
import org.apache.tinkerpop.gremlin.process.traversal.TraversalStrategies;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.AddEdgeStartStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.AddVertexStartStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.GraphStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.sideEffect.InjectStep;
import org.apache.tinkerpop.gremlin.process.traversal.util.DefaultTraversal;
import org.apache.tinkerpop.gremlin.structure.Edge;
import org.apache.tinkerpop.gremlin.structure.Graph;
import org.apache.tinkerpop.gremlin.structure.Vertex;

import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
import javax.annotation.processing.ProcessingEnvironment;
import javax.annotation.processing.RoundEnvironment;
import javax.annotation.processing.SupportedAnnotationTypes;
import javax.annotation.processing.SupportedSourceVersion;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.type.TypeVariable;
import javax.lang.model.util.Elements;
import javax.lang.model.util.Types;
import javax.tools.Diagnostic;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A custom Java annotation processor for the {@link GremlinDsl} annotation that helps to generate DSL classes.
 *
 * @author Stephen Mallette (http://stephen.genoprime.com)
 */
@SupportedAnnotationTypes("org.apache.tinkerpop.gremlin.process.traversal.dsl.GremlinDsl")
@SupportedSourceVersion(SourceVersion.RELEASE_8)
public class GremlinDslProcessor extends AbstractProcessor {
    private Messager messager;
    private Elements elementUtils;
    private Filer filer;
    private Types typeUtils;

    /**
     * Captures the helper services exposed by the processing environment so that the generation methods of this
     * processor can use them later.
     *
     * @param processingEnv the environment supplied to this processor
     */
    @Override
    public synchronized void init(final ProcessingEnvironment processingEnv) {
        super.init(processingEnv);
        this.typeUtils = processingEnv.getTypeUtils();
        this.filer = processingEnv.getFiler();
        this.elementUtils = processingEnv.getElementUtils();
        this.messager = processingEnv.getMessager();
    }

    /**
     * Generates the DSL classes for every element annotated with {@link GremlinDsl} in the current round. Any
     * exception raised during validation or generation stops the round and is reported as an error through the
     * {@code Messager}.
     *
     * @param annotations the annotation types requested to be processed (not inspected here)
     * @param roundEnv    the environment for the current processing round
     * @return always {@code true}
     */
    @Override
    public boolean process(final Set<? extends TypeElement> annotations, final RoundEnvironment roundEnv) {
        try {
            for (Element dslElement : roundEnv.getElementsAnnotatedWith(GremlinDsl.class)) {
                validateDSL(dslElement);
                final Context ctx = new Context((TypeElement) dslElement);

                // creates the "Traversal" interface using an extension of the GraphTraversal class that has the
                // GremlinDsl annotation on it
                generateTraversalInterface(ctx);

                // create the "DefaultTraversal" class which implements the above generated "Traversal" and can then
                // be used by the "TraversalSource" generated below to spawn new traversal instances.
                generateDefaultTraversal(ctx);

                // create the "TraversalSource" class which is used to spawn traversals from a Graph instance. It will
                // spawn instances of the "DefaultTraversal" generated above.
                generateTraversalSource(ctx);

                // create anonymous traversal for DSL
                generateAnonymousTraversal(ctx);
            }
        } catch (Exception ex) {
            // many runtime exceptions (e.g. NullPointerException) carry no message - fall back to the exception's
            // string form so that the reported error is never empty
            final String message = ex.getMessage();
            messager.printMessage(Diagnostic.Kind.ERROR, null == message ? ex.toString() : message);
        }

        return true;
    }

    /**
     * Generates the final, non-instantiable {@code __} class of the DSL. It consists of a static {@code start()}
     * method plus one static method for each method of the {@link GremlinDsl} annotated type and each method of
     * TinkerPop's {@link __} (other than {@code start}), all of which delegate to a traversal created by
     * {@code start()}.
     *
     * @param ctx the naming and type information of the DSL being generated
     * @throws IOException if the generated source file cannot be written
     */
    private void generateAnonymousTraversal(final Context ctx) throws IOException {
        final TypeSpec.Builder anonymousClass = TypeSpec.classBuilder("__")
                .addModifiers(Modifier.PUBLIC, Modifier.FINAL);

        // this class is just static methods - it should not be instantiated
        anonymousClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PRIVATE)
                .build());

        // add start() method
        anonymousClass.addMethod(MethodSpec.methodBuilder("start")
                .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
                .addTypeVariable(TypeVariableName.get("A"))
                .addStatement("return new $N<>()", ctx.defaultTraversalClazz)
                .returns(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("A"), TypeVariableName.get("A")))
                .build());

        // process the methods of the GremlinDsl annotated class
        for (ExecutableElement templateMethod : findMethodsOfElement(ctx.annotatedDslType, null)) {
            // the annotation is optional - without it (or with empty values) everything is inferred from the method
            final GremlinDsl.AnonymousMethod methodAnnotation = templateMethod.getAnnotation(GremlinDsl.AnonymousMethod.class);
            final String[] returnTypeOverrides = null == methodAnnotation ? new String[0] : methodAnnotation.returnTypeParameters();
            final String[] methodTypeOverrides = null == methodAnnotation ? new String[0] : methodAnnotation.methodTypeParameters();

            final String methodName = templateMethod.getSimpleName().toString();

            // either use the direct return type of the DSL specification or override it with specification from
            // GremlinDsl.AnonymousMethod
            final TypeName returnType = returnTypeOverrides.length > 0 ?
                    getOverridenReturnTypeDefinition(ctx.traversalClassName, returnTypeOverrides) :
                    getReturnTypeDefinition(ctx.traversalClassName, templateMethod);

            final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(methodName)
                    .addModifiers(Modifier.STATIC, Modifier.PUBLIC)
                    .addExceptions(templateMethod.getThrownTypes().stream().map(TypeName::get).collect(Collectors.toList()))
                    .returns(returnType);

            // either use the method type parameter specified from the GremlinDsl.AnonymousMethod or just infer them
            // from the DSL specification. "inferring" relies on convention and sometimes doesn't work for all cases.
            final String startGeneric;
            if (methodTypeOverrides.length > 0) {
                startGeneric = methodTypeOverrides[0];
                Stream.of(methodTypeOverrides).map(TypeVariableName::get).forEach(methodToAdd::addTypeVariable);
            } else {
                startGeneric = "S";
                templateMethod.getTypeParameters().forEach(tp -> methodToAdd.addTypeVariable(TypeVariableName.get(tp)));

                // might have to deal with an "S" (in __ it's usually an "A") - how to make this less bound to that
                // convention? "S" is declared a single time even if it shows up more than once in the return type
                // as declaring the same type variable twice would produce a method that does not compile
                final boolean returnTypeUsesStart = getTypeArguments(templateMethod).stream()
                        .filter(rtm -> rtm instanceof TypeVariable)
                        .anyMatch(rtm -> ((TypeVariable) rtm).asElement().getSimpleName().contentEquals("S"));
                if (returnTypeUsesStart)
                    methodToAdd.addTypeVariable(TypeVariableName.get("S"));
            }

            addMethodBody(methodToAdd, templateMethod, "return __.<" + startGeneric + ">start().$L(", ")", methodName);
            anonymousClass.addMethod(methodToAdd.build());
        }

        // use methods from __ to template them into the DSL __
        final Element anonymousTraversal = elementUtils.getTypeElement(__.class.getCanonicalName());
        final Predicate<ExecutableElement> ignore = ee -> ee.getSimpleName().contentEquals("start");
        for (ExecutableElement templateMethod : findMethodsOfElement(anonymousTraversal, ignore)) {
            final String methodName = templateMethod.getSimpleName().toString();

            final TypeName returnType = getReturnTypeDefinition(ctx.traversalClassName, templateMethod);
            final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(methodName)
                    .addModifiers(Modifier.STATIC, Modifier.PUBLIC)
                    .addExceptions(templateMethod.getThrownTypes().stream().map(TypeName::get).collect(Collectors.toList()))
                    .returns(returnType);

            templateMethod.getTypeParameters().forEach(tp -> methodToAdd.addTypeVariable(TypeVariableName.get(tp)));

            if (methodName.equals("__")) {
                // __(...) simply forwards to inject(...). the argument list is built from the actual parameter
                // names of the template so that the body always agrees with the generated signature.
                addMethodBody(methodToAdd, templateMethod, "return inject(", ")");
            } else if (templateMethod.getTypeParameters().isEmpty()) {
                // no method level generics to rely on, so start() is typed with the first type argument of the
                // template's return type
                final List<? extends TypeMirror> types = getTypeArguments(templateMethod);
                addMethodBody(methodToAdd, templateMethod, "return __.<$T>start().$L(", ")", types.get(0), methodName);
            } else {
                addMethodBody(methodToAdd, templateMethod, "return __.<A>start().$L(", ")", methodName);
            }

            anonymousClass.addMethod(methodToAdd.build());
        }

        final JavaFile traversalSourceJavaFile = JavaFile.builder(ctx.packageName, anonymousClass.build()).build();
        traversalSourceJavaFile.writeTo(filer);
    }

    /**
     * Generates the DSL "TraversalSource" class. The class extends the traversal source type configured on the
     * {@link GremlinDsl} annotation, re-declares the methods of {@link GraphTraversalSource} that return a
     * {@code GraphTraversalSource} so that they return the DSL type, overrides the methods declared directly on a
     * user supplied traversal source so that they return the DSL traversal and, if enabled on the annotation,
     * overrides the standard start steps ({@code addV}, {@code addE}, {@code V}, {@code E}, {@code inject}) and
     * {@code getAnonymousTraversalClass}.
     *
     * @param ctx the naming and type information of the DSL being generated
     * @throws IOException if the generated source file cannot be written
     */
    private void generateTraversalSource(final Context ctx) throws IOException {
        final TypeElement graphTraversalSourceElement = ctx.traversalSourceDslType;
        final TypeSpec.Builder traversalSourceClass = TypeSpec.classBuilder(ctx.traversalSourceClazz)
                .addModifiers(Modifier.PUBLIC)
                .superclass(TypeName.get(graphTraversalSourceElement.asType()));

        // add the required constructors for instantiation
        traversalSourceClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(Graph.class, "graph")
                .addStatement("super($N)", "graph")
                .build());
        traversalSourceClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(Graph.class, "graph")
                .addParameter(TraversalStrategies.class, "strategies")
                .addStatement("super($N, $N)", "graph", "strategies")
                .build());
        traversalSourceClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(RemoteConnection.class, "connection")
                .addStatement("super($N)", "connection")
                .build());

        // override methods to return the DSL TraversalSource. find GraphTraversalSource class somewhere in the hierarchy
        final Element tinkerPopsGraphTraversalSource = findClassAsElement(graphTraversalSourceElement, GraphTraversalSource.class);
        final Predicate<ExecutableElement> ignore = e -> !(e.getReturnType().getKind() == TypeKind.DECLARED && ((DeclaredType) e.getReturnType()).asElement().getSimpleName().contentEquals(GraphTraversalSource.class.getSimpleName()));
        for (ExecutableElement elementOfGraphTraversalSource : findMethodsOfElement(tinkerPopsGraphTraversalSource, ignore)) {
            // first copy/override methods that return a GraphTraversalSource so that we can instead return
            // the DSL TraversalSource class.
            traversalSourceClass.addMethod(constructMethod(elementOfGraphTraversalSource, ctx.traversalSourceClassName, "", Modifier.PUBLIC));
        }

        // override methods that return GraphTraversal that come from the user defined extension of GraphTraversal
        if (!graphTraversalSourceElement.getSimpleName().contentEquals(GraphTraversalSource.class.getSimpleName())) {
            for (ExecutableElement templateMethod : findMethodsOfElement(graphTraversalSourceElement, null)) {
                // carry over the declared exceptions and method level type variables of the template - as in the
                // other generated overrides - otherwise the override cannot reference those type variables or
                // call a super method that declares checked exceptions
                final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(templateMethod.getSimpleName().toString())
                        .addModifiers(Modifier.PUBLIC)
                        .addAnnotation(Override.class)
                        .addExceptions(templateMethod.getThrownTypes().stream().map(TypeName::get).collect(Collectors.toList()));
                templateMethod.getTypeParameters().forEach(tp -> methodToAdd.addTypeVariable(TypeVariableName.get(tp)));

                methodToAdd.addStatement("$T clone = this.clone()", ctx.traversalSourceClassName);
                addMethodBody(methodToAdd, templateMethod, "return new $T (clone, super.$L(", ").asAdmin())",
                        ctx.defaultTraversalClassName, templateMethod.getSimpleName());
                methodToAdd.returns(getReturnTypeDefinition(ctx.traversalClassName, templateMethod));

                traversalSourceClass.addMethod(methodToAdd.build());
            }
        }

        if (ctx.generateDefaultMethods) {
            // return types shared by the vertex and edge oriented start steps below
            final TypeName vertexTraversalType = ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Vertex.class), ClassName.get(Vertex.class));
            final TypeName edgeTraversalType = ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Edge.class), ClassName.get(Edge.class));

            // override methods that return GraphTraversal
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addV")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addV)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, (String) null))", ctx.traversalClassName, AddVertexStartStep.class)
                    .returns(vertexTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addV")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(String.class, "label")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addV, label)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, label))", ctx.traversalClassName, AddVertexStartStep.class)
                    .returns(vertexTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addV")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Traversal.class, "vertexLabelTraversal")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addV, vertexLabelTraversal)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, vertexLabelTraversal))", ctx.traversalClassName, AddVertexStartStep.class)
                    .returns(vertexTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addE")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(String.class, "label")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addE, label)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, label))", ctx.traversalClassName, AddEdgeStartStep.class)
                    .returns(edgeTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addE")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Traversal.class, "edgeLabelTraversal")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addE, edgeLabelTraversal)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, edgeLabelTraversal))", ctx.traversalClassName, AddEdgeStartStep.class)
                    .returns(edgeTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("V")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Object[].class, "vertexIds")
                    .varargs(true)
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.V, vertexIds)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, $T.class, true, vertexIds))", ctx.traversalClassName, GraphStep.class, Vertex.class)
                    .returns(vertexTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("E")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Object[].class, "edgeIds")
                    .varargs(true)
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.E, edgeIds)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, $T.class, true, edgeIds))", ctx.traversalClassName, GraphStep.class, Edge.class)
                    .returns(edgeTraversalType)
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("inject")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(ArrayTypeName.of(TypeVariableName.get("S")), "starts")
                    .varargs(true)
                    .addTypeVariable(TypeVariableName.get("S"))
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.inject, starts)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, starts))", ctx.traversalClassName, InjectStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("S"), TypeVariableName.get("S")))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("getAnonymousTraversalClass")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addStatement("return Optional.of(__.class)")
                    .returns(ParameterizedTypeName.get(Optional.class, Class.class))
                    .build());
        }

        final JavaFile traversalSourceJavaFile = JavaFile.builder(ctx.packageName, traversalSourceClass.build()).build();
        traversalSourceJavaFile.writeTo(filer);
    }

    /**
     * Walks up the type hierarchy starting at {@code element}, following the first direct supertype at each level,
     * until an element is found whose simple name equals the simple name of {@code clazz}.
     *
     * @param element the element to start the search from
     * @param clazz   the class to look for (matched by simple name only)
     * @return the element in the hierarchy that matches {@code clazz}
     * @throws IllegalStateException if the hierarchy is exhausted without finding a match
     */
    private Element findClassAsElement(final Element element, final Class<?> clazz) {
        final String targetName = clazz.getSimpleName();

        Element current = element;
        while (null != current) {
            if (current.getSimpleName().contentEquals(targetName)) {
                return current;
            }

            // only the first direct supertype is followed - NOTE(review): this assumes the superclass is listed
            // before any interfaces; confirm against the Types.directSupertypes contract
            final List<? extends TypeMirror> supertypes = typeUtils.directSupertypes(current.asType());
            current = supertypes.isEmpty() ? null : typeUtils.asElement(supertypes.get(0));
        }

        // fail with a meaningful message rather than an index or null pointer error from walking off the hierarchy
        throw new IllegalStateException(String.format("%s does not have %s in its type hierarchy", element, targetName));
    }

    /**
     * Generates the "DefaultTraversal" class of the DSL: a public class with type variables {@code S} and {@code E}
     * that extends {@link DefaultTraversal} and implements the generated DSL "Traversal" interface. It gets four
     * constructors and covariant overrides of {@code iterate}, {@code asAdmin} and {@code clone}.
     *
     * @param ctx the naming and type information of the DSL being generated
     * @throws IOException if the generated source file cannot be written
     */
    private void generateDefaultTraversal(final Context ctx) throws IOException {
        final TypeVariableName startType = TypeVariableName.get("S");
        final TypeVariableName endType = TypeVariableName.get("E");

        // the parameterized types used as super interface and as return types of the overrides
        final TypeName dslTraversalOfSE = ParameterizedTypeName.get(ctx.traversalClassName, startType, endType);
        final TypeName adminOfSE = ParameterizedTypeName.get(ctx.graphTraversalAdminClassName, startType, endType);
        final TypeName defaultTraversalOfSE = ParameterizedTypeName.get(ctx.defaultTraversalClassName, startType, endType);
        final TypeName tinkerPopDefaultTraversal =
                TypeName.get(elementUtils.getTypeElement(DefaultTraversal.class.getCanonicalName()).asType());

        // constructors required for instantiation
        final MethodSpec noArgConstructor = MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addStatement("super()")
                .build();
        final MethodSpec graphConstructor = MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(Graph.class, "graph")
                .addStatement("super($N)", "graph")
                .build();
        final MethodSpec sourceConstructor = MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(ctx.traversalSourceClassName, "traversalSource")
                .addStatement("super($N)", "traversalSource")
                .build();
        final MethodSpec sourceAndTraversalConstructor = MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(ctx.traversalSourceClassName, "traversalSource")
                .addParameter(ctx.graphTraversalAdminClassName, "traversal")
                .addStatement("super($N, $N.asAdmin())", "traversalSource", "traversal")
                .build();

        // overrides that narrow the return type by casting the result of the super call
        final MethodSpec iterateOverride = MethodSpec.methodBuilder("iterate")
                .addModifiers(Modifier.PUBLIC)
                .addAnnotation(Override.class)
                .returns(dslTraversalOfSE)
                .addStatement("return ($T) super.iterate()", ctx.traversalClassName)
                .build();
        final MethodSpec asAdminOverride = MethodSpec.methodBuilder("asAdmin")
                .addModifiers(Modifier.PUBLIC)
                .addAnnotation(Override.class)
                .returns(adminOfSE)
                .addStatement("return ($T) super.asAdmin()", GraphTraversal.Admin.class)
                .build();
        final MethodSpec cloneOverride = MethodSpec.methodBuilder("clone")
                .addModifiers(Modifier.PUBLIC)
                .addAnnotation(Override.class)
                .returns(defaultTraversalOfSE)
                .addStatement("return ($T) super.clone()", ctx.defaultTraversalClassName)
                .build();

        final TypeSpec defaultTraversalClass = TypeSpec.classBuilder(ctx.defaultTraversalClazz)
                .addModifiers(Modifier.PUBLIC)
                .addTypeVariables(Arrays.asList(startType, endType))
                .superclass(tinkerPopDefaultTraversal)
                .addSuperinterface(dslTraversalOfSE)
                .addMethod(noArgConstructor)
                .addMethod(graphConstructor)
                .addMethod(sourceConstructor)
                .addMethod(sourceAndTraversalConstructor)
                .addMethod(iterateOverride)
                .addMethod(asAdminOverride)
                .addMethod(cloneOverride)
                .build();

        JavaFile.builder(ctx.packageName, defaultTraversalClass).build().writeTo(filer);
    }

    /**
     * Generates the DSL "Traversal" interface: a public interface with type variables {@code S} and {@code E} that
     * extends the {@link GremlinDsl} annotated type. Default methods are generated for the methods of the annotated
     * type first and then for the methods of {@link GraphTraversal} (skipping {@code asAdmin} and {@code iterate}),
     * followed by an explicit {@code iterate} override.
     *
     * @param ctx the naming and type information of the DSL being generated
     * @throws IOException if the generated source file cannot be written
     */
    private void generateTraversalInterface(final Context ctx) throws IOException {
        final TypeVariableName startType = TypeVariableName.get("S");
        final TypeVariableName endType = TypeVariableName.get("E");

        final TypeSpec.Builder traversalInterface = TypeSpec.interfaceBuilder(ctx.traversalClazz)
                .addModifiers(Modifier.PUBLIC)
                .addTypeVariables(Arrays.asList(startType, endType))
                .addSuperinterface(TypeName.get(ctx.annotatedDslType.asType()));

        // methods of the GremlinDsl annotated class come first, those of GraphTraversal second
        final TypeElement graphTraversalElement = elementUtils.getTypeElement(GraphTraversal.class.getCanonicalName());
        final Predicate<ExecutableElement> ignore = method ->
                Stream.of("asAdmin", "iterate").anyMatch(skipped -> method.getSimpleName().contentEquals(skipped));
        Stream.concat(
                findMethodsOfElement(ctx.annotatedDslType, null).stream(),
                findMethodsOfElement(graphTraversalElement, ignore).stream())
                .map(templateMethod -> constructMethod(templateMethod, ctx.traversalClassName, ctx.dslName,
                        Modifier.PUBLIC, Modifier.DEFAULT))
                .forEach(traversalInterface::addMethod);

        // there are weird things with generics that require this method to be implemented if it isn't already present
        // in the GremlinDsl annotated class extending from GraphTraversal
        final MethodSpec iterateOverride = MethodSpec.methodBuilder("iterate")
                .addModifiers(Modifier.PUBLIC, Modifier.DEFAULT)
                .addAnnotation(Override.class)
                .returns(ParameterizedTypeName.get(ctx.traversalClassName, startType, endType))
                .addStatement("$T.super.iterate()", ClassName.get(ctx.annotatedDslType))
                .addStatement("return this")
                .build();
        traversalInterface.addMethod(iterateOverride);

        JavaFile.builder(ctx.packageName, traversalInterface.build()).build().writeTo(filer);
    }

    /**
     * Builds an {@code @Override} of the supplied template method that calls the super implementation with the same
     * arguments and casts the result to {@code returnClazz}.
     *
     * @param element     the method to override - must be an {@link ExecutableElement}
     * @param returnClazz the class used for the return type and for the cast of the super call's result
     * @param parent      the type name to qualify the {@code super} call with or an empty string for a plain
     *                    {@code super} call
     * @param modifiers   the modifiers to apply to the generated method
     * @return the specification of the generated method
     */
    private MethodSpec constructMethod(final Element element, final ClassName returnClazz, final String parent,
                                       final Modifier... modifiers) {
        final ExecutableElement template = (ExecutableElement) element;
        final String name = template.getSimpleName().toString();
        final TypeName generatedReturnType = getReturnTypeDefinition(returnClazz, template);
        final List<TypeName> thrownTypes = template.getThrownTypes().stream()
                .map(TypeName::get)
                .collect(Collectors.toList());

        final MethodSpec.Builder override = MethodSpec.methodBuilder(name);
        override.addModifiers(modifiers);
        override.addAnnotation(Override.class);
        override.addExceptions(thrownTypes);
        override.returns(generatedReturnType);
        template.getTypeParameters().forEach(typeParameter -> override.addTypeVariable(TypeVariableName.get(typeParameter)));

        // "super" on its own or qualified as "<parent>.super"
        final String superReference = parent.isEmpty() ? "super" : parent + ".super";
        addMethodBody(override, template, "return ($T) " + superReference + ".$L(", ")", returnClazz, name);

        return override.build();
    }

    /**
     * Copies the parameters of the template method onto the method being built and adds a single statement made of
     * {@code startBody}, the comma separated parameter names and {@code endBody}. If the last parameter is an array
     * the generated method is marked as varargs.
     *
     * @param methodToAdd    the builder of the method being generated
     * @param templateMethod the method whose parameters are copied
     * @param startBody      the statement format up to the argument list
     * @param endBody        the statement format following the argument list
     * @param statementArgs  the arguments for the placeholders used in the statement format
     */
    private void addMethodBody(final MethodSpec.Builder methodToAdd, final ExecutableElement templateMethod,
                               final String startBody, final String endBody, final Object... statementArgs) {
        final StringBuilder statement = new StringBuilder(startBody);
        final List<? extends VariableElement> templateParams = templateMethod.getParameters();

        // same parameters as the template, in the same order
        templateParams.forEach(templateParam -> methodToAdd.addParameter(ParameterSpec.get(templateParam)));

        // the names of those parameters form the argument list of the delegating call
        final String argumentList = templateParams.stream()
                .map(templateParam -> templateParam.getSimpleName().toString())
                .collect(Collectors.joining(","));
        statement.append(argumentList).append(endBody);

        // treat a final array as a varargs param
        final boolean endsWithArray = !templateParams.isEmpty()
                && TypeKind.ARRAY == templateParams.get(templateParams.size() - 1).asType().getKind();
        if (endsWithArray) {
            methodToAdd.varargs(true);
        }

        methodToAdd.addStatement(statement.toString(), statementArgs);
    }

    /**
     * Builds a parameterized version of {@code returnClazz} from type arguments given as strings.
     *
     * @param returnClazz the raw type to parameterize
     * @param typeValues  the type arguments - each is either the name of a loadable class, the name of a type
     *                    variable or a bounded type variable in the form {@code "T extends Bound"}
     * @return {@code returnClazz} parameterized with the parsed type arguments
     * @throws IllegalArgumentException if a value contains "extends" but is not in the form {@code "T extends Bound"}
     */
    private TypeName getOverridenReturnTypeDefinition(final ClassName returnClazz, final String[] typeValues) {
        final TypeName[] typeArguments = Stream.of(typeValues).map(this::parseTypeArgument).toArray(TypeName[]::new);
        return ParameterizedTypeName.get(returnClazz, typeArguments);
    }

    /**
     * Converts a single type argument given as a string to a {@link ClassName} or a {@link TypeVariableName}.
     */
    private TypeName parseTypeArgument(final String typeValue) {
        try {
            return ClassName.get(Class.forName(typeValue));
        } catch (ClassNotFoundException ignored) {
            // not the name of a loadable class so it is handled as a type variable below
        }

        if (!typeValue.contains("extends"))
            return TypeVariableName.get(typeValue);

        // tolerate any amount of whitespace around the keyword and fail clearly on anything that isn't
        // "<name> extends <bound>" rather than with an index error
        final String[] sides = typeValue.trim().split("\\s+extends\\s+", 2);
        if (sides.length < 2 || sides[0].isEmpty() || sides[1].isEmpty())
            throw new IllegalArgumentException(String.format(
                    "The type parameter '%s' is expected to be in the form of 'T extends Bound'", typeValue));

        // the bound is either a loadable class or another type variable
        TypeName bound;
        try {
            bound = ClassName.get(Class.forName(sides[1]));
        } catch (Exception ex) {
            bound = TypeVariableName.get(sides[1]);
        }

        // withBounds() does not modify the instance it is called on - the bounded variable is the one it returns
        return TypeVariableName.get(sides[0]).withBounds(bound);
    }

    /**
     * Determines the return type of a generated method: {@code returnClazz} carrying the same type arguments as the
     * return type of the template method or the raw {@code returnClazz} if that return type has no type arguments.
     *
     * @param returnClazz    the class to use as the return type
     * @param templateMethod the method whose return type supplies the type arguments
     * @return the return type to use for the generated method
     */
    private TypeName getReturnTypeDefinition(final ClassName returnClazz, final ExecutableElement templateMethod) {
        final List<? extends TypeMirror> templateArguments = getTypeArguments(templateMethod);
        if (templateArguments.isEmpty()) {
            return returnClazz;
        }

        // convert each type argument of the template to its JavaPoet representation
        final TypeName[] typeArguments = new TypeName[templateArguments.size()];
        for (int ix = 0; ix < typeArguments.length; ix++) {
            typeArguments[ix] = TypeName.get(templateArguments.get(ix));
        }

        return ParameterizedTypeName.get(returnClazz, typeArguments);
    }

    /**
     * Checks that the element annotated with {@link GremlinDsl} is a public interface.
     *
     * @param dslElement the annotated element
     * @throws ProcessorException if the element is not an interface or if the interface is not public
     */
    private void validateDSL(final Element dslElement) throws ProcessorException {
        final boolean isInterface = ElementKind.INTERFACE == dslElement.getKind();
        if (!isInterface) {
            throw new ProcessorException(dslElement, "Only interfaces can be annotated with @%s", GremlinDsl.class.getSimpleName());
        }

        final TypeElement dslType = (TypeElement) dslElement;
        final boolean isPublic = dslType.getModifiers().contains(Modifier.PUBLIC);
        if (!isPublic) {
            throw new ProcessorException(dslElement, "The interface %s is not public.", dslType.getQualifiedName());
        }
    }

    /**
     * Collects the methods directly enclosed by the supplied element.
     *
     * @param element the element whose enclosed methods are returned
     * @param ignore  a predicate matching the methods to leave out or {@code null} to keep all methods
     * @return the enclosed methods that are not matched by {@code ignore}
     */
    private List<ExecutableElement> findMethodsOfElement(final Element element, final Predicate<ExecutableElement> ignore) {
        // without an "ignore" predicate every method is kept
        final Predicate<ExecutableElement> keep = null == ignore ? method -> true : ignore.negate();
        return element.getEnclosedElements().stream()
                .filter(enclosed -> ElementKind.METHOD == enclosed.getKind())
                .map(ExecutableElement.class::cast)
                .filter(keep)
                .collect(Collectors.toList());
    }

    /**
     * Gets the type arguments of the return type of the supplied method. The return type is cast to
     * {@link DeclaredType} without a prior check.
     *
     * @param templateMethod the method whose return type is inspected
     * @return the type arguments of the return type
     */
    private List<? extends TypeMirror> getTypeArguments(final ExecutableElement templateMethod) {
        return ((DeclaredType) templateMethod.getReturnType()).getTypeArguments();
    }

    /**
     * Holds the names and types derived from a {@link GremlinDsl} annotated interface that the generation methods
     * share.
     */
    private class Context {
        private final TypeElement annotatedDslType;
        private final String packageName;
        private final String dslName;
        private final String traversalClazz;
        private final ClassName traversalClassName;
        private final String traversalSourceClazz;
        private final ClassName traversalSourceClassName;
        private final String defaultTraversalClazz;
        private final ClassName defaultTraversalClassName;
        private final ClassName graphTraversalAdminClassName;
        private final TypeElement traversalSourceDslType;
        private final boolean generateDefaultMethods;

        /**
         * Derives the context from the annotated DSL interface.
         *
         * @param dslElement the interface annotated with {@link GremlinDsl}
         * @throws IllegalStateException if the traversal source configured on the annotation cannot be resolved or
         *                               if the name of the interface is too short to carry the expected suffix
         */
        public Context(final TypeElement dslElement) {
            annotatedDslType = dslElement;

            // gets the annotation on the dsl class/interface
            final GremlinDsl gremlinDslAnnotation = dslElement.getAnnotation(GremlinDsl.class);
            generateDefaultMethods = gremlinDslAnnotation.generateDefaultMethods();

            // fail here with a meaningful message rather than with a null pointer once generation is under way
            traversalSourceDslType = elementUtils.getTypeElement(gremlinDslAnnotation.traversalSource());
            if (null == traversalSourceDslType)
                throw new IllegalStateException(String.format("The traversalSource %s defined on %s could not be found",
                        gremlinDslAnnotation.traversalSource(), dslElement.getQualifiedName()));

            packageName = getPackageName(dslElement, gremlinDslAnnotation);

            // create the Traversal implementation interface
            dslName = dslElement.getSimpleName().toString();

            // chop off "TraversalDSL" - only the length of the suffix is used, its content is not checked
            final int suffixLength = "TraversalDSL".length();
            if (dslName.length() < suffixLength)
                throw new IllegalStateException(String.format(
                        "The name of the interface %s is too short - it is expected to end with a suffix like \"TraversalDsl\"",
                        dslElement.getQualifiedName()));
            final String dslPrefix = dslName.substring(0, dslName.length() - suffixLength);

            traversalClazz = dslPrefix + "Traversal";
            traversalClassName = ClassName.get(packageName, traversalClazz);
            traversalSourceClazz = dslPrefix + "TraversalSource";
            traversalSourceClassName = ClassName.get(packageName, traversalSourceClazz);
            defaultTraversalClazz = "Default" + traversalClazz;
            defaultTraversalClassName = ClassName.get(packageName, defaultTraversalClazz);
            graphTraversalAdminClassName = ClassName.get(GraphTraversal.Admin.class);
        }

        /**
         * Gets the package for the generated classes: the package name given on the annotation or, if that is
         * empty, the package of the annotated interface.
         */
        private String getPackageName(final Element dslElement, final GremlinDsl gremlinDslAnnotation) {
            return gremlinDslAnnotation.packageName().isEmpty() ?
                    elementUtils.getPackageOf(dslElement).getQualifiedName().toString() :
                    gremlinDslAnnotation.packageName();
        }
    }
}