package cz.habarta.typescript.generator.spring; import cz.habarta.typescript.generator.Settings; import cz.habarta.typescript.generator.TsType; import cz.habarta.typescript.generator.TypeProcessor; import cz.habarta.typescript.generator.TypeScriptGenerator; import cz.habarta.typescript.generator.parser.JaxrsApplicationParser; import cz.habarta.typescript.generator.parser.MethodParameterModel; import cz.habarta.typescript.generator.parser.PathTemplate; import cz.habarta.typescript.generator.parser.RestApplicationModel; import cz.habarta.typescript.generator.parser.RestApplicationParser; import cz.habarta.typescript.generator.parser.RestApplicationType; import cz.habarta.typescript.generator.parser.RestMethodModel; import cz.habarta.typescript.generator.parser.RestQueryParam; import cz.habarta.typescript.generator.parser.SourceType; import cz.habarta.typescript.generator.type.JTypeWithNullability; import cz.habarta.typescript.generator.util.GenericsResolver; import cz.habarta.typescript.generator.util.Pair; import cz.habarta.typescript.generator.util.Utils; import static cz.habarta.typescript.generator.util.Utils.getInheritanceChain; import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.Stream; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.BridgeMethodResolver; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.data.domain.Pageable; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.ValueConstants; public class SpringApplicationParser extends RestApplicationParser { // This factory class is instantiated using reflections! public static class Factory extends RestApplicationParser.Factory { @Override public TypeProcessor getSpecificTypeProcessor() { return (javaType, context) -> { final Class<?> rawClass = Utils.getRawClassOrNull(javaType); if (rawClass != null) { for (Map.Entry<Class<?>, TsType> entry : getStandardEntityClassesMapping().entrySet()) { final Class<?> cls = entry.getKey(); final TsType type = entry.getValue(); if (cls.isAssignableFrom(rawClass) && type != null) { return new TypeProcessor.Result(type); } } if (getDefaultExcludedClassNames().contains(rawClass.getName())) { return new TypeProcessor.Result(TsType.Any); } } return null; }; } @Override public RestApplicationParser create(Settings settings, TypeProcessor commonTypeProcessor) { return new SpringApplicationParser(settings, commonTypeProcessor); } } public SpringApplicationParser(Settings settings, TypeProcessor commonTypeProcessor) { super(settings, commonTypeProcessor, new RestApplicationModel(RestApplicationType.Spring)); } @Override public JaxrsApplicationParser.Result tryParse(SourceType<?> sourceType) { if (!(sourceType.type instanceof Class<?>)) { return null; } final Class<?> cls = (Class<?>) sourceType.type; // application final SpringBootApplication app = AnnotationUtils.findAnnotation(cls, SpringBootApplication.class); if (app != null) { if (settings.scanSpringApplication) { TypeScriptGenerator.getLogger().verbose("Scanning Spring application: " + cls.getName()); final ClassLoader originalContextClassLoader = Thread.currentThread().getContextClassLoader(); try { Thread.currentThread().setContextClassLoader(settings.classLoader); final SpringApplicationHelper springApplicationHelper = new SpringApplicationHelper(settings.classLoader, cls); final List<Class<?>> restControllers = springApplicationHelper.findRestControllers(); return new JaxrsApplicationParser.Result(restControllers.stream() .map(controller -> new SourceType<Type>(controller, cls, "<scanned>")) .collect(Collectors.toList()) ); } finally { Thread.currentThread().setContextClassLoader(originalContextClassLoader); } } else { return null; } } // controller final Component component = AnnotationUtils.findAnnotation(cls, Component.class); if (component != null) { TypeScriptGenerator.getLogger().verbose("Parsing Spring component: " + cls.getName()); final JaxrsApplicationParser.Result result = new JaxrsApplicationParser.Result(); final RequestMapping requestMapping = AnnotatedElementUtils.findMergedAnnotation(cls, RequestMapping.class); final String path = requestMapping != null && requestMapping.path() != null && requestMapping.path().length != 0 ? requestMapping.path()[0] : null; final JaxrsApplicationParser.ResourceContext context = new JaxrsApplicationParser.ResourceContext(cls, path); parseController(result, context, cls); return result; } return null; } private class SpringApplicationHelper extends SpringApplication { private final ClassLoader classLoader; public SpringApplicationHelper(ClassLoader classLoader, Class<?>... primarySources) { super(primarySources); this.classLoader = classLoader; } public List<Class<?>> findRestControllers() { try (ConfigurableApplicationContext context = createApplicationContext()) { load(context, getAllSources().toArray()); withSystemProperty("server.port", "0", () -> { context.refresh(); }); final List<Class<?>> classes = Stream.of(context.getBeanDefinitionNames()) .map(beanName -> context.getBeanFactory().getBeanDefinition(beanName).getBeanClassName()) .filter(Objects::nonNull) .filter(className -> isClassNameExcluded == null || !isClassNameExcluded.test(className)) .map(className -> { try { return classLoader.loadClass(className); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } }) .filter(instance -> AnnotationUtils.findAnnotation(instance, Component.class) != null) .collect(Collectors.toList()); return classes; } } } private static void withSystemProperty(String name, String value, Runnable runnable) { final String original = System.getProperty(name); try { System.setProperty(name, value); runnable.run(); } finally { if (original != null) { System.setProperty(name, original); } else { System.getProperties().remove(name); } } } private void parseController(JaxrsApplicationParser.Result result, JaxrsApplicationParser.ResourceContext context, Class<?> controllerClass) { // parse controller methods final List<Method> methods = getAllRequestMethods(controllerClass); methods.sort(Utils.methodComparator()); for (Method method : methods) { parseControllerMethod(result, context, controllerClass, method); } } private List<Method> getAllRequestMethods(Class<?> cls) { List<Method> currentlyResolvedMethods = new ArrayList<>(); getInheritanceChain(cls) .forEach(clazz -> { for (Method method : clazz.getDeclaredMethods()) { final RequestMapping requestMapping = AnnotatedElementUtils.findMergedAnnotation(method, RequestMapping.class); if (requestMapping != null) { addOrReplaceMethod(currentlyResolvedMethods, method); } } }); return currentlyResolvedMethods; } private void addOrReplaceMethod(List<Method> resolvedMethods, Method newMethod) { int methodIndex = getMethodIndex(resolvedMethods, newMethod); if (methodIndex == -1) { resolvedMethods.add(newMethod); return; } final Method bridgedMethod = BridgeMethodResolver.findBridgedMethod(newMethod); int bridgedMethodIndex = getMethodIndex(resolvedMethods, bridgedMethod); if (bridgedMethodIndex == -1 || bridgedMethodIndex == methodIndex) { resolvedMethods.set(methodIndex, bridgedMethod); } else { resolvedMethods.set(bridgedMethodIndex, bridgedMethod); resolvedMethods.remove(methodIndex); } } private int getMethodIndex(List<Method> resolvedMethods, Method newMethod) { for (int i = 0; i < resolvedMethods.size(); i++) { Method currMethod = resolvedMethods.get(i); if (!currMethod.getName().equals(newMethod.getName())) continue; if (!Arrays.equals(currMethod.getParameterTypes(), newMethod.getParameterTypes())) continue; return i; } return -1; } // https://docs.spring.io/spring/docs/current/spring-framework-reference/web.html#mvc-ann-methods private void parseControllerMethod(JaxrsApplicationParser.Result result, JaxrsApplicationParser.ResourceContext context, Class<?> controllerClass, Method method) { final RequestMapping requestMapping = AnnotatedElementUtils.findMergedAnnotation(method, RequestMapping.class); if (requestMapping != null) { // subContext context = context.subPath(requestMapping.path().length == 0 ? "" : requestMapping.path()[0]); final Map<String, Type> pathParamTypes = new LinkedHashMap<>(); for (Parameter parameter : method.getParameters()) { final PathVariable pathVariableAnnotation = AnnotationUtils.findAnnotation(parameter, PathVariable.class); if (pathVariableAnnotation != null) { String pathVariableName = pathVariableAnnotation.value(); // https://docs.spring.io/spring/docs/3.2.x/spring-framework-reference/html/mvc.html#mvc-ann-requestmapping-uri-templates // Can be empty if the URI template variable matches the method argument if (pathVariableName.isEmpty()) { pathVariableName = parameter.getName(); } pathParamTypes.put(pathVariableName, parameter.getParameterizedType()); } } context = context.subPathParamTypes(pathParamTypes); final RequestMethod httpMethod = requestMapping.method().length == 0 ? RequestMethod.GET : requestMapping.method()[0]; // path parameters final PathTemplate pathTemplate = PathTemplate.parse(context.path); final Map<String, Type> contextPathParamTypes = context.pathParamTypes; final List<MethodParameterModel> pathParams = pathTemplate.getParts().stream() .filter(PathTemplate.Parameter.class::isInstance) .map(PathTemplate.Parameter.class::cast) .map(parameter -> { final Type type = contextPathParamTypes.get(parameter.getOriginalName()); final Type paramType = type != null ? type : String.class; foundType(result, paramType, controllerClass, method.getName()); return new MethodParameterModel(parameter.getValidName(), paramType); }) .collect(Collectors.toList()); // query parameters final List<RestQueryParam> queryParams = new ArrayList<>(); for (Parameter parameter : method.getParameters()) { if (parameter.getType() == Pageable.class) { queryParams.add(new RestQueryParam.Single(new MethodParameterModel("page", Long.class), false)); foundType(result, Long.class, controllerClass, method.getName()); queryParams.add(new RestQueryParam.Single(new MethodParameterModel("size", Long.class), false)); foundType(result, Long.class, controllerClass, method.getName()); queryParams.add(new RestQueryParam.Single(new MethodParameterModel("sort", String.class), false)); foundType(result, String.class, controllerClass, method.getName()); } else { final RequestParam requestParamAnnotation = AnnotationUtils.findAnnotation(parameter, RequestParam.class); if (requestParamAnnotation != null) { final boolean isRequired = requestParamAnnotation.required() && requestParamAnnotation.defaultValue().equals(ValueConstants.DEFAULT_NONE); queryParams.add(new RestQueryParam.Single(new MethodParameterModel(firstOf( requestParamAnnotation.value(), parameter.getName() ), parameter.getParameterizedType()), isRequired)); foundType(result, parameter.getParameterizedType(), controllerClass, method.getName()); } } } // entity parameter final MethodParameterModel entityParameter = getEntityParameter(controllerClass, method); if (entityParameter != null) { foundType(result, entityParameter.getType(), controllerClass, method.getName()); } final Type modelReturnType = parseReturnType(controllerClass, method); foundType(result, modelReturnType, controllerClass, method.getName()); model.getMethods().add(new RestMethodModel(controllerClass, method.getName(), modelReturnType, controllerClass, httpMethod.name(), context.path, pathParams, queryParams, entityParameter, null)); } } private Type parseReturnType(Class<?> controllerClass, Method method) { final Class<?> returnType = method.getReturnType(); final Type parsedReturnType = settings.getTypeParser().getMethodReturnType(method); final Type plainReturnType = JTypeWithNullability.getPlainType(parsedReturnType); final Type modelReturnType; if (returnType == void.class) { modelReturnType = returnType; } else if (plainReturnType instanceof ParameterizedType && returnType == ResponseEntity.class) { final ParameterizedType parameterizedReturnType = (ParameterizedType) plainReturnType; modelReturnType = parameterizedReturnType.getActualTypeArguments()[0]; } else { modelReturnType = parsedReturnType; } return GenericsResolver.resolveType(controllerClass, modelReturnType, method.getDeclaringClass()); } private MethodParameterModel getEntityParameter(Class<?> controller, Method method) { final List<Type> parameterTypes = settings.getTypeParser().getMethodParameterTypes(method); final List<Pair<Parameter, Type>> parameters = Utils.zip(Arrays.asList(method.getParameters()), parameterTypes); for (Pair<Parameter, Type> pair : parameters) { final RequestBody requestBodyAnnotation = AnnotationUtils.findAnnotation(pair.getValue1(), RequestBody.class); if (requestBodyAnnotation != null) { final Type resolvedType = GenericsResolver.resolveType(controller, pair.getValue2(), method.getDeclaringClass()); return new MethodParameterModel(pair.getValue1().getName(), resolvedType); } } return null; } private static Map<Class<?>, TsType> getStandardEntityClassesMapping() { if (standardEntityClassesMapping == null) { final Map<Class<?>, TsType> map = new LinkedHashMap<>(); standardEntityClassesMapping = map; } return standardEntityClassesMapping; } private static Map<Class<?>, TsType> standardEntityClassesMapping; private static List<String> getDefaultExcludedClassNames() { return Arrays.asList( ); } private static String firstOf(String... values) { return Stream.of(values).filter(it -> it != null && !it.isEmpty()).findFirst().orElse(""); } }