package com.github.jenspiegsa.wiremockextension; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; import static org.junit.platform.commons.util.ReflectionUtils.makeAccessible; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.stream.Stream; import org.junit.jupiter.api.extension.AfterEachCallback; import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.ExtensionConfigurationException; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.TestInstancePostProcessor; import org.junit.platform.commons.support.AnnotationSupport; import org.junit.platform.commons.util.AnnotationUtils; import org.junit.platform.commons.util.ReflectionUtils; import com.github.tomakehurst.wiremock.WireMockServer; import com.github.tomakehurst.wiremock.client.VerificationException; import com.github.tomakehurst.wiremock.client.WireMock; import com.github.tomakehurst.wiremock.core.Options; import com.github.tomakehurst.wiremock.verification.LoggedRequest; import com.github.tomakehurst.wiremock.verification.NearMiss; /** * @author Jens Piegsa */ public class WireMockExtension implements BeforeEachCallback, AfterEachCallback, TestInstancePostProcessor { private boolean generalFailOnUnmatchedRequests; /** * {@link ExtensionContext.Namespace} in which WireMockServers are stored, * keyed by test class. */ private static final ExtensionContext.Namespace NAMESPACE = ExtensionContext.Namespace.create(WireMockExtension.class); // This constructor is invoked by JUnit via reflection @SuppressWarnings("unused") private WireMockExtension() { this(true); } public WireMockExtension(final boolean failOnUnmatchedRequests) { generalFailOnUnmatchedRequests = failOnUnmatchedRequests; } @Override public void postProcessTestInstance(final Object testInstance, final ExtensionContext context) throws Exception { final List<WireMockServer> managedServers = retrieveAnnotatedFields(context, Managed.class, WireMockServer.class).stream() .map(field -> ReflectionUtils.readFieldValue(field, testInstance)) .map(Optional::get) .map(WireMockServer.class::cast) .map(Objects::requireNonNull) .collect(toList()); if (managedServers.isEmpty()) { Options options = null; for (final Field field : retrieveAnnotatedFields(context, ConfigureWireMock.class, Options.class)) { if (options == null) { options = (Options) makeAccessible(field).get(testInstance); } else { throw new ExtensionConfigurationException("@ConfigureWireMock only valid once per class."); } } if (options == null) { options = wireMockConfig(); } final List<Field> injectedServerFields = retrieveAnnotatedFields(context, InjectServer.class, WireMockServer.class); if (!injectedServerFields.isEmpty()) { final WireMockServer server = new WireMockServer(options); for (final Field field : injectedServerFields) { makeAccessible(field).set(testInstance, server); } context.getStore(NAMESPACE).put(testInstance.getClass(), singletonList(server)); } } else { context.getStore(NAMESPACE).put(testInstance.getClass(), managedServers); } } @Override public void beforeEach(final ExtensionContext context) { final Optional<WireMockSettings> wireMockSettings = retrieveAnnotation(context, WireMockSettings.class); generalFailOnUnmatchedRequests = wireMockSettings .map(WireMockSettings::failOnUnmatchedRequests) .orElse(generalFailOnUnmatchedRequests); final List<WireMockServer> wireMockServers = collectServers(context); if (wireMockServers.isEmpty()) { // Simple case final WireMockServer server = new WireMockServer(); context.getStore(NAMESPACE).put(context.getRequiredTestClass(), singletonList(server)); startServer(server); } else { wireMockServers.forEach(WireMockExtension::startServer); } } @Override public void afterEach(final ExtensionContext context) { final List<WireMockServer> wireMockServers = collectServers(context); // Stopping all servers first wireMockServers.forEach(WireMockExtension::stopServer); wireMockServers.forEach(this::checkForUnmatchedRequests); } private void checkForUnmatchedRequests(final WireMockServer server) { final boolean mustCheck = Optional.of(server) .filter(ManagedWireMockServer.class::isInstance) .map(ManagedWireMockServer.class::cast) .map(ManagedWireMockServer::failOnUnmatchedRequests) .orElse(generalFailOnUnmatchedRequests); if (mustCheck) { final List<LoggedRequest> unmatchedRequests = server.findAllUnmatchedRequests(); if (!unmatchedRequests.isEmpty()) { final List<NearMiss> nearMisses = server.findNearMissesForAllUnmatchedRequests(); throw nearMisses.isEmpty() ? VerificationException.forUnmatchedRequests(unmatchedRequests) : VerificationException.forUnmatchedNearMisses(nearMisses); } } } private static <A extends Annotation> Optional<A> retrieveAnnotation(final ExtensionContext context, final Class<A> annotationType) { Optional<ExtensionContext> currentContext = Optional.of(context); Optional<A> annotation = Optional.empty(); while (currentContext.isPresent() && !annotation.isPresent()) { annotation = AnnotationSupport.findAnnotation(currentContext.get().getElement(), annotationType); currentContext = currentContext.get().getParent(); } return annotation; } private static List<Field> retrieveAnnotatedFields(final ExtensionContext context, final Class<? extends Annotation> annotationType, final Class<?> fieldType) { return context.getElement() .filter(Class.class::isInstance) .map(Class.class::cast) .map(testInstanceClass -> AnnotationUtils.findAnnotatedFields(testInstanceClass, annotationType, field -> fieldType.isAssignableFrom(field.getType())) ) .orElseGet(Collections::emptyList); } private static void startServer(final WireMockServer server) { if (!server.isRunning()) { server.start(); WireMock.configureFor("localhost", server.port()); } } private static void stopServer(final WireMockServer server) { server.stop(); } private static List<WireMockServer> collectServers(final ExtensionContext context) { return collectTestClasses(context) .map(testClass -> context.getStore(NAMESPACE).get(testClass)) .filter(Objects::nonNull) .map(List.class::cast) .flatMap(Collection::stream) .map(WireMockServer.class::cast) .collect(toList()); } private static Stream<Class<?>> collectTestClasses(final ExtensionContext context) { return Stream.concat( Stream.of(context.getRequiredTestClass()), context.getParent() .filter(parentContext -> parentContext != context.getRoot()) .map(WireMockExtension::collectTestClasses) .orElseGet(Stream::empty) ).distinct(); } }