package com.smockin.mockserver.engine; import com.smockin.admin.enums.UserModeEnum; import com.smockin.admin.persistence.dao.RestfulMockDAO; import com.smockin.admin.service.SmockinUserService; import com.smockin.admin.websocket.LiveLoggingHandler; import com.smockin.mockserver.dto.MockServerState; import com.smockin.mockserver.dto.MockedServerConfigDTO; import com.smockin.mockserver.exception.MockServerException; import com.smockin.mockserver.service.*; import com.smockin.mockserver.service.ws.SparkWebSocketEchoService; import com.smockin.utils.GeneralUtils; import com.smockin.utils.LiveLoggingUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import spark.Request; import spark.Response; import spark.Spark; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; /** * Created by mgallina. */ @Service @Transactional(readOnly = true) public class MockedRestServerEngine implements MockServerEngine<MockedServerConfigDTO> { private final Logger logger = LoggerFactory.getLogger(MockedRestServerEngine.class); @Autowired private RestfulMockDAO restfulMockDAO; @Autowired private RuleEngine ruleEngine; @Autowired private HttpProxyService proxyService; @Autowired private MockOrderingCounterService mockOrderingCounterService; @Autowired private InboundParamMatchService inboundParamMatchService; @Autowired private WebSocketService webSocketService; @Autowired private ServerSideEventService serverSideEventService; @Autowired private MockedRestServerEngineUtils mockedRestServerEngineUtils; @Autowired private LiveLoggingHandler liveLoggingHandler; @Autowired private SmockinUserService smockinUserService; private final Object monitor = new Object(); private MockServerState serverState = new MockServerState(false, 0); private final String wildcardPath = "*"; @Override public void start(final MockedServerConfigDTO config) throws MockServerException { logger.debug("start called"); initServerConfig(config); final boolean isMultiUserMode = UserModeEnum.ACTIVE.equals(smockinUserService.getUserMode()); // Define all web socket routes first as the Spark framework requires this buildWebSocketEndpoints(isMultiUserMode); // Handle Cross-Origin Resource Sharing (CORS) support handleCORS(config); // Next handle all HTTP RESTFul web service routes buildGlobalHttpEndpointsHandler(isMultiUserMode, config); applyTrafficLogging(config.isProxyMode()); initServer(config.getPort()); } @Override public MockServerState getCurrentState() throws MockServerException { synchronized (monitor) { return serverState; } } @Override public void shutdown() throws MockServerException { try { serverSideEventService.interruptAndClearAllHeartBeatThreads(); Spark.stop(); // Having dug around the source code, 'Spark.stop()' runs off a different thread when stopping the server and removing it's state such as routes, etc. // This means that calling 'Spark.port()' immediately after stop, results in an IllegalStateException, as the // 'initialized' flag is checked in the current thread and is still marked as true. // (The error thrown: java.lang.IllegalStateException: This must be done before route mapping has begun) // Short of editing the Spark source to fix this, I have therefore had to add this hack to buy the 'stop' thread time to complete. Thread.sleep(3000); synchronized (monitor) { serverState.setRunning(false); } clearState(); } catch (Throwable ex) { throw new MockServerException(ex); } } void initServer(final int port) throws MockServerException { logger.debug("initServer called"); try { clearState(); Spark.init(); // Blocks the current thread (using a CountDownLatch under the hood) until the server is fully initialised. Spark.awaitInitialization(); synchronized (monitor) { serverState.setRunning(true); serverState.setPort(port); } } catch (Throwable ex) { throw new MockServerException(ex); } } void initServerConfig(final MockedServerConfigDTO config) { logger.debug("initServerConfig called"); if (logger.isDebugEnabled()) logger.debug(config.toString()); Spark.port(config.getPort()); Spark.threadPool(config.getMaxThreads(), config.getMinThreads(), config.getTimeOutMillis()); } void buildWebSocketEndpoints(final boolean isMultiUserMode) { Spark.webSocket("/*", new SparkWebSocketEchoService(webSocketService, isMultiUserMode)); } private void applyTrafficLogging(final boolean isUsingProxyMode) { // Live logging filter Spark.before((request, response) -> { final String traceId = GeneralUtils.generateUUID(); request.attribute(GeneralUtils.LOG_REQ_ID, traceId); response.raw() .addHeader(GeneralUtils.LOG_REQ_ID, traceId); final Map<String, String> reqHeaders = request .headers() .stream() .collect(Collectors.toMap(h -> h, h -> request.headers(h))); reqHeaders.put(GeneralUtils.LOG_REQ_ID, traceId); liveLoggingHandler.broadcast( LiveLoggingUtils.buildLiveLogInboundDTO( request.attribute(GeneralUtils.LOG_REQ_ID), request.requestMethod(), request.pathInfo(), reqHeaders, request.body(), isUsingProxyMode, GeneralUtils.extractAllRequestParams(request))); }); Spark.afterAfter((request, response) -> { if (serverSideEventService.SSE_EVENT_STREAM_HEADER.equals(response.raw().getHeader(HttpHeaders.CONTENT_TYPE))) { return; } boolean isProxiedResponse = isProxiedResponse(isUsingProxyMode, response); final Map<String, String> respHeaders = response .raw() .getHeaderNames() .stream() .filter(h -> !GeneralUtils.PROXIED_RESPONSE_HEADER.equalsIgnoreCase(h)) .collect(Collectors.toMap(h -> h, h -> response.raw().getHeader(h))); respHeaders.put(GeneralUtils.LOG_REQ_ID, request.attribute(GeneralUtils.LOG_REQ_ID)); liveLoggingHandler.broadcast( LiveLoggingUtils.buildLiveLogOutboundDTO( request.attribute(GeneralUtils.LOG_REQ_ID), response.raw().getStatus(), respHeaders, response.body(), isUsingProxyMode, isProxiedResponse)); }); } void buildGlobalHttpEndpointsHandler(final boolean isMultiUserMode, final MockedServerConfigDTO config) { logger.debug("buildGlobalHttpEndpointsHandler called"); Spark.head(wildcardPath, (request, response) -> mockedRestServerEngineUtils.loadMockedResponse(request, response, isMultiUserMode, config) .orElseGet(() -> handleNotFoundResponse(response))); Spark.get(wildcardPath, (request, response) -> { if (isWebSocketUpgradeRequest(request)) { response.status(HttpStatus.OK.value()); return null; } return mockedRestServerEngineUtils.loadMockedResponse(request, response, isMultiUserMode, config) .orElseGet(() -> handleNotFoundResponse(response)); }); Spark.post(wildcardPath, (request, response) -> mockedRestServerEngineUtils.loadMockedResponse(request, response, isMultiUserMode, config) .orElseGet(() -> handleNotFoundResponse(response))); Spark.put(wildcardPath, (request, response) -> mockedRestServerEngineUtils.loadMockedResponse(request, response, isMultiUserMode, config) .orElseGet(() -> handleNotFoundResponse(response))); Spark.delete(wildcardPath, (request, response) -> mockedRestServerEngineUtils.loadMockedResponse(request, response, isMultiUserMode, config) .orElseGet(() -> handleNotFoundResponse(response))); Spark.patch(wildcardPath, (request, response) -> mockedRestServerEngineUtils.loadMockedResponse(request, response, isMultiUserMode, config) .orElseGet(() -> handleNotFoundResponse(response))); } private String handleNotFoundResponse(final Response response) { response.status(HttpStatus.NOT_FOUND.value()); return ""; } private boolean isWebSocketUpgradeRequest(final Request request) { final Set<String> headerNames = request.headers(); return headerNames.contains(HttpHeaders.UPGRADE) && headerNames.contains("Sec-WebSocket-Key") && "websocket".equalsIgnoreCase(request.headers(HttpHeaders.UPGRADE)); } void clearState() { // Proxy related state webSocketService.clearSession(); proxyService.clearAllSessions(); mockOrderingCounterService.clearState(); serverSideEventService.clearState(); } void handleCORS(final MockedServerConfigDTO config) { final String enableCors = config.getNativeProperties().get(GeneralUtils.ENABLE_CORS_PARAM); if (!Boolean.TRUE.toString().equalsIgnoreCase(enableCors)) { return; } Spark.options("/*", (request, response) -> { final String accessControlRequestHeaders = request.headers(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS); if (accessControlRequestHeaders != null) { response.header(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS, accessControlRequestHeaders); } final String accessControlRequestMethod = request.headers(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); if (accessControlRequestMethod != null) { response.header(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS, accessControlRequestMethod); } return HttpStatus.OK.name(); }); Spark.before((request, response) -> { response.header(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, wildcardPath); }); } boolean isProxiedResponse(final boolean isUsingProxyMode, final Response response) { if (!isUsingProxyMode) { return false; } final String proxyHeader = response .raw() .getHeader(GeneralUtils.PROXIED_RESPONSE_HEADER); return (proxyHeader != null && Boolean.TRUE.toString().equals(proxyHeader)); } }