/* * Copyright 2016-2018 The OpenTracing Authors * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package io.opentracing.contrib.web.servlet.filter; import java.io.IOException; import java.util.Collections; import java.util.EnumSet; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; import javax.servlet.AsyncContext; import javax.servlet.DispatcherType; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.mockito.Mockito; import io.opentracing.Scope; import io.opentracing.Span; import io.opentracing.SpanContext; import io.opentracing.Tracer; import io.opentracing.mock.MockSpan; import io.opentracing.mock.MockTracer; import io.opentracing.util.ThreadLocalScopeManager; /** * @author Pavol Loffay */ public abstract class AbstractJettyTest { // jetty starts on a random port private int serverPort; protected String contextPath = "/context"; protected Server jettyServer; protected MockTracer mockTracer; @Before public void beforeTest() throws Exception { mockTracer = Mockito.spy(new MockTracer(new ThreadLocalScopeManager(), MockTracer.Propagator.TEXT_MAP)); ServletContextHandler servletContext = new ServletContextHandler(); servletContext.setContextPath(contextPath); servletContext.addServlet(TestServlet.class, "/hello"); ServletHolder asyncServletHolder = new ServletHolder(new AsyncServlet(mockTracer)); servletContext.addServlet(asyncServletHolder, "/async"); asyncServletHolder.setAsyncSupported(true); servletContext.addServlet(AsyncImmediateExitServlet.class, "/asyncImmediateExit") .setAsyncSupported(true); ServletHolder timeoutServletHolder = new ServletHolder(new AsyncTimeoutServlet()); timeoutServletHolder.setAsyncSupported(true); servletContext.addServlet(timeoutServletHolder, "/asyncTimeout"); servletContext.addServlet(new ServletHolder(new LocalSpanServlet(mockTracer)), "/localSpan"); servletContext.addServlet(new ServletHolder(new CurrentSpanServlet(mockTracer)), "/currentSpan"); servletContext.addServlet(ExceptionServlet.class, "/servletException"); servletContext.addFilter(new FilterHolder(tracingFilter()), "/*", EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.ASYNC, DispatcherType.ERROR, DispatcherType.INCLUDE)); servletContext.addFilter(ErrorFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST)); initServletContext(servletContext); jettyServer = new Server(0); jettyServer.setHandler(servletContext); jettyServer.start(); serverPort = ((ServerConnector)jettyServer.getConnectors()[0]).getLocalPort(); } protected void initServletContext(ServletContextHandler servletContext) { } @After public void afterTest() throws Exception { jettyServer.stop(); jettyServer.join(); } protected Filter tracingFilter() { return new TracingFilter(mockTracer, Collections.singletonList(ServletFilterSpanDecorator.STANDARD_TAGS), Pattern.compile("/health")); } public String localRequestUrl(String path) { return "http://localhost:" + serverPort + ("/".equals(contextPath) ? "" : contextPath) + path; } public static void assertOnErrors(List<MockSpan> spans) { for (MockSpan mockSpan: spans) { Assert.assertEquals(mockSpan.generatedErrors().toString(), 0, mockSpan.generatedErrors().size()); } } Callable<Integer> reportedSpansSize() { return new Callable<Integer>() { @Override public Integer call() throws Exception { return mockTracer.finishedSpans().size(); } }; } public static class TestServlet extends HttpServlet { @Override public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { // Check mock tracer is available in servlet context response.setStatus(getServletContext().getAttribute(Tracer.class.getName()) instanceof MockTracer ? HttpServletResponse.SC_ACCEPTED : HttpServletResponse.SC_EXPECTATION_FAILED); } } public static class LocalSpanServlet extends HttpServlet { private io.opentracing.Tracer tracer; public LocalSpanServlet(Tracer tracer) { this.tracer = tracer; } @Override public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { SpanContext spanContext = (SpanContext)request.getAttribute(TracingFilter.SERVER_SPAN_CONTEXT); tracer.buildSpan("localSpan") .asChildOf(spanContext) .start() .finish(); } } public static class CurrentSpanServlet extends HttpServlet { private io.opentracing.Tracer tracer; public CurrentSpanServlet(Tracer tracer) { this.tracer = tracer; } @Override public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { tracer.activeSpan().setTag("CurrentSpan", true); } } public static class ExceptionServlet extends HttpServlet { public static final String EXCEPTION_MESSAGE = ExceptionServlet.class.getName() + "message"; @Override public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { throw new ServletException(EXCEPTION_MESSAGE); } } public static class AsyncServlet extends HttpServlet { public static int ASYNC_SLEEP_TIME_MS = 250; private io.opentracing.Tracer tracer; public AsyncServlet(Tracer tracer) { this.tracer = tracer; } @Override public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { final AsyncContext asyncContext = request.startAsync(request, response); // TODO: This could be avoided by using an OpenTracing aware Runnable (when available) final Span cont = tracer.activeSpan(); asyncContext.start(new Runnable() { @Override public void run() { HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse(); try (Scope activeScope = tracer.scopeManager().activate(cont)) { try { Thread.sleep(ASYNC_SLEEP_TIME_MS); asyncResponse.setStatus(204); } catch (InterruptedException e) { e.printStackTrace(); asyncResponse.setStatus(500); } finally { asyncContext.complete(); } } } }); } } public static class AsyncImmediateExitServlet extends HttpServlet { @Override public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { final AsyncContext asyncContext = request.startAsync(request, response); response.setStatus(204); asyncContext.complete(); } } public static class AsyncTimeoutServlet extends HttpServlet { @Override public void doGet(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { // avoid retries on timeout if (request.getAttribute("timedOut") != null) { return; } request.setAttribute("timedOut", true); final AsyncContext asyncContext = request.startAsync(request, response); asyncContext.setTimeout(10); asyncContext.start(new Runnable() { @Override public void run() { try { TimeUnit.MILLISECONDS.sleep(200); } catch (InterruptedException e) { e.printStackTrace(); } finally { asyncContext.complete(); } } }); } } public static class ErrorFilter implements Filter { public static final String EXCEPTION_MESSAGE = ErrorFilter.class.getName() + "message"; @Override public void init(FilterConfig filterConfig) throws ServletException {} @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletRequest httpServletRequest = (HttpServletRequest) request; if ("/filterException".equals(httpServletRequest.getServletPath())) { throw new RuntimeException(EXCEPTION_MESSAGE); } chain.doFilter(request, response); } @Override public void destroy() {} } }