package com.massfords.jaxb;

import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JPackage;
import com.sun.codemodel.JTypeVar;
import com.sun.codemodel.JVar;
import com.sun.tools.xjc.outline.ClassOutline;
import com.sun.tools.xjc.outline.Outline;

import java.util.Collections;
import java.util.Set;
import java.util.function.Function;

import static com.massfords.jaxb.ClassDiscoverer.allConcreteClasses;

/**
 * Creates a traversing visitor. This visitor pairs a visitor and a traverser. The result is a visitor that
 * will traverse the entire graph and visit each of the nodes using the provided visitor.
 *
 * @author markford
 */
class CreateTraversingVisitorClass extends CodeCreator {

    /** Name of the generated field (and bare reference) holding the optional progress monitor. */
    private static final String PROGRESS_MONITOR = "progressMonitor";

    /** Name of the generated boolean field selecting traverse-before-visit vs. visit-before-traverse. */
    private static final String TRAVERSE_FIRST = "traverseFirst";

    /** Name of the single parameter of every generated visit method. */
    private static final String BEAN_PARAM = "aBean";

    private final JDefinedClass progressMonitor;
    private final JDefinedClass visitor;
    private final JDefinedClass traverser;

    /**
     * Function that accepts a type name and returns the name of the method to
     * create. This encapsulates the behavior associated with the includeType
     * flag. This applies to the visitor methods.
     */
    private final Function<String, String> visitMethodNamer;

    /**
     * Function that accepts a type name and returns the name of the method to
     * create. This encapsulates the behavior associated with the includeType
     * flag. This applies to the traverser methods.
     */
    private final Function<String, String> traverseMethodNamer;

    /**
     * @param visitor             generated visitor interface the traversing visitor implements
     * @param progressMonitor     generated progress monitor type stored in the traversing visitor
     * @param traverser           generated traverser type the traversing visitor delegates to
     * @param outline             XJC outline, passed through to {@link CodeCreator}
     * @param jPackage            package the new class is created in, passed through to {@link CodeCreator}
     * @param visitMethodNamer    maps a type name to the visitor method name for that type
     * @param traverseMethodNamer maps a type name to the traverser method name for that type
     */
    CreateTraversingVisitorClass(JDefinedClass visitor, JDefinedClass progressMonitor,
                                 JDefinedClass traverser, Outline outline, JPackage jPackage,
                                 Function<String, String> visitMethodNamer,
                                 Function<String, String> traverseMethodNamer) {
        super(outline, jPackage);
        this.visitor = visitor;
        this.traverser = traverser;
        this.progressMonitor = progressMonitor;
        this.visitMethodNamer = visitMethodNamer;
        this.traverseMethodNamer = traverseMethodNamer;
    }

    /**
     * Builds the {@code TraversingVisitor} class: a class generic in a return type {@code R} and an
     * exception type {@code E extends Throwable} that implements the narrowed visitor interface,
     * holds a traverser, a visitor, a {@code traverseFirst} flag and a progress monitor (each with
     * a getter and setter), and overrides one visit method per class in {@code classes}
     * (as expanded by {@code allConcreteClasses}) and per class in {@code directClasses}.
     *
     * @param classes       class outlines whose concrete classes get an accept-based visit method
     * @param directClasses classes whose visit method calls the wrapped visitor directly
     */
    @Override
    protected void run(Set<ClassOutline> classes, Set<JClass> directClasses) {
        JDefinedClass traversingVisitor =
                getOutline().getClassFactory().createClass(getPackage(), "TraversingVisitor", null);

        final JTypeVar returnType = traversingVisitor.generify("R");
        final JTypeVar exceptionType = traversingVisitor.generify("E", Throwable.class);
        final JClass narrowedVisitor = visitor.narrow(returnType).narrow(exceptionType);
        final JClass narrowedTraverser = traverser.narrow(exceptionType);
        traversingVisitor._implements(narrowedVisitor);

        JMethod ctor = traversingVisitor.constructor(JMod.PUBLIC);
        ctor.param(narrowedTraverser, "aTraverser");
        ctor.param(narrowedVisitor, "aVisitor");

        JFieldVar fieldTraverseFirst = traversingVisitor.field(JMod.PRIVATE, Boolean.TYPE, TRAVERSE_FIRST);
        JFieldVar fieldVisitor = traversingVisitor.field(JMod.PRIVATE, narrowedVisitor, "visitor");
        JFieldVar fieldTraverser = traversingVisitor.field(JMod.PRIVATE, narrowedTraverser, "traverser");
        JFieldVar fieldMonitor = traversingVisitor.field(JMod.PRIVATE, progressMonitor, PROGRESS_MONITOR);

        addGetterAndSetter(traversingVisitor, fieldTraverseFirst);
        addGetterAndSetter(traversingVisitor, fieldVisitor);
        addGetterAndSetter(traversingVisitor, fieldTraverser);
        addGetterAndSetter(traversingVisitor, fieldMonitor);

        ctor.body().assign(fieldTraverser, JExpr.ref("aTraverser"));
        ctor.body().assign(fieldVisitor, JExpr.ref("aVisitor"));

        setOutput(traversingVisitor);

        for (JClass jc : allConcreteClasses(classes, Collections.emptySet())) {
            generate(traversingVisitor, returnType, exceptionType, jc);
        }
        for (JClass jc : directClasses) {
            generateForDirectClass(traversingVisitor, returnType, exceptionType, jc);
        }
    }

    /**
     * Adds the visit method for a direct class. The generated method obtains its result by calling
     * the wrapped visitor's own visit method with the bean, i.e.
     * {@code getVisitor().<visitMethod>(aBean)}, rather than calling {@code accept} on the bean.
     *
     * @param traversingVisitor class receiving the method
     * @param returnType        the {@code R} type variable of the traversing visitor
     * @param exceptionType     the {@code E} type variable of the traversing visitor
     * @param implClass         type of the bean parameter
     */
    private void generateForDirectClass(JDefinedClass traversingVisitor, JTypeVar returnType,
                                        JTypeVar exceptionType, JClass implClass) {
        emitVisitMethod(traversingVisitor, returnType, exceptionType, implClass, false);
    }

    /**
     * Adds the visit method for a generated bean class. The generated method obtains its result
     * through double dispatch: {@code aBean.accept(getVisitor())}.
     *
     * @param traversingVisitor class receiving the method
     * @param returnType        the {@code R} type variable of the traversing visitor
     * @param exceptionType     the {@code E} type variable of the traversing visitor
     * @param implClass         type of the bean parameter
     */
    private void generate(JDefinedClass traversingVisitor, JTypeVar returnType,
                          JTypeVar exceptionType, JClass implClass) {
        emitVisitMethod(traversingVisitor, returnType, exceptionType, implClass, true);
    }

    /**
     * Shared implementation of {@link #generate} and {@link #generateForDirectClass}. Emits, in order:
     * the traverse-before-visit block, the declaration and assignment of {@code returnVal}, the
     * {@code visited} monitor notification, the traverse-after-visit block and the return statement.
     *
     * @param traversingVisitor     class receiving the method
     * @param returnType            return type of the generated method
     * @param exceptionType         type the generated method declares in its throws clause
     * @param implClass             type of the bean parameter; its name feeds the visit method namer
     * @param dispatchThroughAccept {@code true} to assign {@code aBean.accept(getVisitor())},
     *                              {@code false} to assign {@code getVisitor().<visitMethod>(aBean)}
     */
    private void emitVisitMethod(JDefinedClass traversingVisitor, JTypeVar returnType,
                                 JTypeVar exceptionType, JClass implClass,
                                 boolean dispatchThroughAccept) {
        String visitMethodName = visitMethodNamer.apply(implClass.name());

        // add method impl to traversing visitor
        JMethod travViz = traversingVisitor.method(JMod.PUBLIC, returnType, visitMethodName);
        travViz._throws(exceptionType);
        JVar beanVar = travViz.param(implClass, BEAN_PARAM);
        travViz.annotate(Override.class);
        JBlock travVizBloc = travViz.body();

        // case to traverse before the visit
        addTraverseBlock(travViz, beanVar, true);

        JVar retVal = travVizBloc.decl(returnType, "returnVal");
        if (dispatchThroughAccept) {
            travVizBloc.assign(retVal, JExpr.invoke(beanVar, "accept").arg(JExpr.invoke("getVisitor")));
        } else {
            travVizBloc.assign(retVal,
                    JExpr.invoke(JExpr.invoke("getVisitor"), visitMethodName).arg(beanVar));
        }
        notifyMonitor(travVizBloc, "visited", beanVar);

        // case to traverse after the visit
        addTraverseBlock(travViz, beanVar, false);

        travVizBloc._return(retVal);
    }

    /**
     * Appends to the method body a block guarded by {@code traverseFirst == flag} that calls the
     * traverser for the bean (passing {@code this} as the visitor) and then notifies the progress
     * monitor's {@code traversed} callback.
     *
     * @param travViz generated visit method whose body is extended
     * @param beanVar the bean parameter of {@code travViz}; its type name selects the traverse method
     * @param flag    {@code true} emits the traverse-before-visit case, {@code false} the
     *                traverse-after-visit case
     */
    private void addTraverseBlock(JMethod travViz, JVar beanVar, boolean flag) {
        JBlock travVizBloc = travViz.body();
        JBlock block = travVizBloc._if(JExpr.ref(TRAVERSE_FIRST).eq(JExpr.lit(flag)))._then();
        String traverseMethodName = traverseMethodNamer.apply(beanVar.type().name());
        block.invoke(JExpr.invoke("getTraverser"), traverseMethodName).arg(beanVar).arg(JExpr._this());
        notifyMonitor(block, "traversed", beanVar);
    }

    /**
     * Appends {@code if (progressMonitor != null) progressMonitor.<callback>(aBean);} to the block.
     *
     * @param block    block receiving the statement
     * @param callback name of the progress monitor method to invoke
     * @param beanVar  argument passed to the callback
     */
    private void notifyMonitor(JBlock block, String callback, JVar beanVar) {
        block._if(JExpr.ref(PROGRESS_MONITOR).ne(JExpr._null()))
                ._then()
                .invoke(JExpr.ref(PROGRESS_MONITOR), callback)
                .arg(beanVar);
    }

    /**
     * Convenience method to add a getter and setter method for the given field. The accessors are
     * named {@code get}/{@code set} followed by the field name with its first character upper-cased;
     * the {@code get} prefix is used for every field type, including boolean.
     *
     * @param traversingVisitor class receiving the two accessor methods
     * @param field             field to expose; supplies the accessor type and name
     */
    private void addGetterAndSetter(JDefinedClass traversingVisitor, JFieldVar field) {
        String propName = Character.toUpperCase(field.name().charAt(0)) + field.name().substring(1);

        traversingVisitor.method(JMod.PUBLIC, field.type(), "get" + propName).body()._return(field);

        JMethod setter = traversingVisitor.method(JMod.PUBLIC, void.class, "set" + propName);
        // The parameter is called "aVisitor" for every field, not only the visitor one; the name
        // is left untouched so the generated sources do not change.
        JVar setterParam = setter.param(field.type(), "aVisitor");
        setter.body().assign(field, setterParam);
    }
}