package org.codefx.demo.junit5; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.AfterTestExecutionCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.BeforeTestExecutionCallback; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.ExtensionContext.Namespace; import static java.lang.System.currentTimeMillis; import static java.util.Collections.singletonMap; import static org.junit.platform.commons.support.AnnotationSupport.isAnnotated; class BenchmarkExtension implements BeforeAllCallback, BeforeTestExecutionCallback, AfterTestExecutionCallback, AfterAllCallback { private static final Namespace NAMESPACE = Namespace.create("org", "codefx", "BenchmarkExtension"); // EXTENSION POINTS @Override public void beforeAll(ExtensionContext context) { if (!shouldBeBenchmarked(context)) return; storeNowAsLaunchTime(context, LaunchTimeKey.CLASS); } @Override public void beforeTestExecution(ExtensionContext context) { if (!shouldBeBenchmarked(context)) return; storeNowAsLaunchTime(context, LaunchTimeKey.TEST); } @Override public void afterTestExecution(ExtensionContext context) { if (!shouldBeBenchmarked(context)) return; long launchTime = loadLaunchTime(context, LaunchTimeKey.TEST); long elapsedTime = currentTimeMillis() - launchTime; report("Test", context, elapsedTime); } @Override public void afterAll(ExtensionContext context) { if (!shouldBeBenchmarked(context)) return; long launchTime = loadLaunchTime(context, LaunchTimeKey.CLASS); long elapsedTime = currentTimeMillis() - launchTime; report("Test container", context, elapsedTime); } // HELPER private static boolean shouldBeBenchmarked(ExtensionContext context) { return context.getElement() .map(el -> isAnnotated(el, Benchmark.class)) .orElse(false); } private static void storeNowAsLaunchTime(ExtensionContext context, LaunchTimeKey key) { context.getStore(NAMESPACE).put(key, currentTimeMillis()); } private static long loadLaunchTime(ExtensionContext context, LaunchTimeKey key) { return context.getStore(NAMESPACE).get(key, long.class); } private static void report(String unit, ExtensionContext context, long elapsedTime) { String message = String.format("%s '%s' took %d ms.", unit, context.getDisplayName(), elapsedTime); context.publishReportEntry("Benchmark", message); } private enum LaunchTimeKey { CLASS, TEST } }