package com.nike.riposte.server.testutils; import com.nike.backstopper.apierror.ApiError; import com.nike.backstopper.model.DefaultErrorContractDTO; import com.nike.internal.util.Pair; import com.nike.wingtips.Span; import com.nike.wingtips.lifecyclelistener.SpanLifecycleListener; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.lang3.RandomUtils; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.net.ServerSocket; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.function.Consumer; import java.util.function.Function; import java.util.zip.DataFormatException; import java.util.zip.Deflater; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import java.util.zip.Inflater; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpVersion; import io.netty.util.CharsetUtil; import static io.netty.util.CharsetUtil.UTF_8; import static org.apache.commons.lang3.StringUtils.split; import static org.apache.commons.lang3.StringUtils.substringAfter; import static org.apache.commons.lang3.StringUtils.substringBefore; import static org.assertj.core.api.Assertions.assertThat; /** * Helper methods for working with component tests. * * @author Nic Munroe */ public class ComponentTestUtils { private static final String HEADER_SEPARATOR = ":"; private static final String payloadDictionary = "aBcDefGhiJkLmN@#$%"; private static final ObjectMapper objectMapper = new ObjectMapper(); public static int findFreePort() throws IOException { try (ServerSocket serverSocket = new ServerSocket(0)) { return serverSocket.getLocalPort(); } } public static class SpanRecorder implements SpanLifecycleListener { public final List<Span> completedSpans = Collections.synchronizedList(new ArrayList<>()); @Override public void spanStarted(Span span) { } @Override public void spanSampled(Span span) { } @Override public void spanCompleted(Span span) { completedSpans.add(span); } } public static void waitUntilSpanRecorderHasExpectedNumSpans( SpanRecorder spanRecorder, int expectedNumSpans ) { waitUntilSpanRecorderHasExpectedNumSpans(spanRecorder, expectedNumSpans, 5000); } public static void waitUntilSpanRecorderHasExpectedNumSpans( SpanRecorder spanRecorder, int expectedNumSpans, long timeoutMillis ) { waitUntilCollectionHasSize(spanRecorder.completedSpans, expectedNumSpans, timeoutMillis, "spanRecorder"); // Before we return we need to sort completedSpans by start time (in reverse order to mimic what normally // happens with spans where the last-created is first-completed). We need to do this sort because running // these tests on travis CI can get weird and we can get them completing and arriving in the list in // out-of-expected-order state. spanRecorder.completedSpans.sort(Comparator.comparingLong(Span::getSpanStartTimeNanos).reversed()); } public static void waitUntilCollectionHasSize( Collection<?> collection, int expectedSize, long timeoutMillis, String collectionName ) { long startTimeMillis = System.currentTimeMillis(); while (collection.size() < expectedSize) { try { Thread.sleep(10); } catch (InterruptedException e) { throw new RuntimeException(e); } long timeSinceStart = System.currentTimeMillis() - startTimeMillis; if (timeSinceStart > timeoutMillis) { throw new RuntimeException( collectionName + " did not have the expected size of " + expectedSize + " after waiting " + timeoutMillis + " milliseconds" ); } } } public static String generatePayload(int payloadSize) { return generatePayload(payloadSize, payloadDictionary); } public static String generatePayload(int payloadSize, String dictionary) { StringBuilder payload = new StringBuilder(); for(int i = 0; i < payloadSize; i++) { int randomInt = RandomUtils.nextInt(0, dictionary.length() - 1); payload.append(dictionary.charAt(randomInt)); } return payload.toString(); } public static ByteBuf createByteBufPayload(int payloadSize) { return Unpooled.wrappedBuffer(generatePayload(payloadSize).getBytes(UTF_8)); } public static byte[] gzipPayload(String payload) { ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(bytesOut)) { byte[] payloadBytes = payload.getBytes(UTF_8); gzipOutputStream.write(payloadBytes); gzipOutputStream.finish(); return bytesOut.toByteArray(); } catch (IOException e) { throw new RuntimeException(e); } } public static String ungzipPayload(byte[] compressed) { try { if ((compressed == null) || (compressed.length == 0)) { throw new RuntimeException("Null/empty compressed payload. is_null=" + (compressed == null)); } final StringBuilder outStr = new StringBuilder(); final GZIPInputStream gis = new GZIPInputStream(new ByteArrayInputStream(compressed)); final BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(gis, "UTF-8")); String line; while ((line = bufferedReader.readLine()) != null) { outStr.append(line); } return outStr.toString(); } catch(IOException ex) { throw new RuntimeException(ex); } } public static byte[] deflatePayload(String payload) { Deflater deflater = new Deflater(6, false); byte[] payloadBytes = payload.getBytes(UTF_8); deflater.setInput(payloadBytes); deflater.finish(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; while (!deflater.finished()) { int count = deflater.deflate(buffer); outputStream.write(buffer, 0, count); } return outputStream.toByteArray(); } public static String inflatePayload(byte[] compressed) { Inflater inflater = new Inflater(); inflater.setInput(compressed); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; while (!inflater.finished()) { try { int count = inflater.inflate(buffer); outputStream.write(buffer, 0, count); } catch (DataFormatException e) { throw new RuntimeException(e); } } return new String(outputStream.toByteArray(), UTF_8); } public static String base64Encode(byte[] bytes) { return Base64.getEncoder().encodeToString(bytes); } public static byte[] base64Decode(String encodedStr) { return Base64.getDecoder().decode(encodedStr); } public enum CompressionType { GZIP(ComponentTestUtils::gzipPayload, ComponentTestUtils::ungzipPayload, HttpHeaders.Values.GZIP), DEFLATE(ComponentTestUtils::deflatePayload, ComponentTestUtils::inflatePayload, HttpHeaders.Values.DEFLATE), IDENTITY(s -> s.getBytes(UTF_8), b -> new String(b, UTF_8), HttpHeaders.Values.IDENTITY); private final Function<String, byte[]> compressionFunction; private final Function<byte[], String> decompressionFunction; public final String contentEncodingHeaderValue; CompressionType(Function<String, byte[]> compressionFunction, Function<byte[], String> decompressionFunction, String contentEncodingHeaderValue) { this.compressionFunction = compressionFunction; this.decompressionFunction = decompressionFunction; this.contentEncodingHeaderValue = contentEncodingHeaderValue; } public byte[] compress(String s) { return compressionFunction.apply(s); } public String decompress(byte[] compressed) { return decompressionFunction.apply(compressed); } } public static Bootstrap createNettyHttpClientBootstrap() { Bootstrap bootstrap = new Bootstrap(); bootstrap.group(new NioEventLoopGroup()) .channel(NioSocketChannel.class) .handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline p = ch.pipeline(); p.addLast(new HttpClientCodec()); p.addLast(new HttpObjectAggregator(Integer.MAX_VALUE)); p.addLast("clientResponseHandler", new SimpleChannelInboundHandler<HttpObject>() { @Override protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Exception { throw new RuntimeException("Client response handler was not setup before the call"); } }); } }); return bootstrap; } public static Channel connectNettyHttpClientToLocalServer(Bootstrap bootstrap, int port) throws InterruptedException { return bootstrap.connect("localhost", port).sync().channel(); } public static CompletableFuture<NettyHttpClientResponse> setupNettyHttpClientResponseHandler(Channel ch) { return setupNettyHttpClientResponseHandler(ch, null); } public static CompletableFuture<NettyHttpClientResponse> setupNettyHttpClientResponseHandler( Channel ch, Consumer<ChannelPipeline> pipelineAdjuster ) { CompletableFuture<NettyHttpClientResponse> responseFromServerFuture = new CompletableFuture<>(); ch.pipeline().replace("clientResponseHandler", "clientResponseHandler", new SimpleChannelInboundHandler<HttpObject>() { @Override protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Exception { if (msg instanceof FullHttpResponse) { // Store the proxyServer response for asserting on later. responseFromServerFuture.complete(new NettyHttpClientResponse((FullHttpResponse) msg)); } else { // Should never happen. throw new RuntimeException("Received unexpected message type: " + msg.getClass()); } } }); if (pipelineAdjuster != null) pipelineAdjuster.accept(ch.pipeline()); return responseFromServerFuture; } public static NettyHttpClientResponse executeNettyHttpClientCall( Channel ch, FullHttpRequest request, long incompleteCallTimeoutMillis ) throws ExecutionException, InterruptedException, TimeoutException { return executeNettyHttpClientCall(ch, request, incompleteCallTimeoutMillis, null); } public static NettyHttpClientResponse executeNettyHttpClientCall( Channel ch, FullHttpRequest request, long incompleteCallTimeoutMillis, Consumer<ChannelPipeline> pipelineAdjuster ) throws ExecutionException, InterruptedException, TimeoutException { CompletableFuture<NettyHttpClientResponse> responseFuture = setupNettyHttpClientResponseHandler(ch, pipelineAdjuster); // Send the request. ch.writeAndFlush(request); // Wait for the response to be received return responseFuture.get(incompleteCallTimeoutMillis, TimeUnit.MILLISECONDS); } public static class NettyHttpClientResponse { public final int statusCode; public final HttpHeaders headers; public final String payload; public final byte[] payloadBytes; public final FullHttpResponse fullHttpResponse; public NettyHttpClientResponse(FullHttpResponse fullHttpResponse) { this.statusCode = fullHttpResponse.status().code(); this.headers = fullHttpResponse.headers(); ByteBuf content = fullHttpResponse.content(); this.payloadBytes = new byte[content.readableBytes()]; content.getBytes(content.readerIndex(), this.payloadBytes); this.payload = new String(this.payloadBytes, UTF_8); this.fullHttpResponse = fullHttpResponse; } } public static NettyHttpClientResponse executeRequest( FullHttpRequest request, int port, long incompleteCallTimeoutMillis ) throws InterruptedException, TimeoutException, ExecutionException { return executeRequest(request, port, incompleteCallTimeoutMillis, null); } public static NettyHttpClientResponse executeRequest( FullHttpRequest request, int port, long incompleteCallTimeoutMillis, Consumer<ChannelPipeline> pipelineAdjuster ) throws InterruptedException, TimeoutException, ExecutionException { Bootstrap bootstrap = createNettyHttpClientBootstrap(); try { // Connect to the proxyServer. Channel ch = connectNettyHttpClientToLocalServer(bootstrap, port); try { return executeNettyHttpClientCall(ch, request, incompleteCallTimeoutMillis, pipelineAdjuster); } finally { ch.close(); } } finally { bootstrap.config().group().shutdownGracefully(); } } public static NettyHttpClientRequestBuilder request() { return new NettyHttpClientRequestBuilder(); } public static class NettyHttpClientRequestBuilder { private HttpMethod method; private String uri; private String payload; private HttpHeaders headers = new DefaultHttpHeaders(); private Consumer<ChannelPipeline> pipelineAdjuster; public NettyHttpClientRequestBuilder withMethod(HttpMethod method) { this.method = method; return this; } public NettyHttpClientRequestBuilder withUri(String uri) { this.uri = uri; return this; } public NettyHttpClientRequestBuilder withPaylod(String payload) { this.payload = payload; return this; } public NettyHttpClientRequestBuilder withKeepAlive() { headers.set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.KEEP_ALIVE); return this; } public NettyHttpClientRequestBuilder withConnectionClose() { headers.set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.CLOSE); return this; } public NettyHttpClientRequestBuilder withHeader(CharSequence key, Object value) { this.headers.set(key, value); return this; } public NettyHttpClientRequestBuilder withHeaders(Iterable<Pair<String, Object>> headers) { for (Pair<String, Object> header : headers) { withHeader(header.getKey(), header.getValue()); } return this; } @SafeVarargs public final NettyHttpClientRequestBuilder withHeaders(Pair<String, Object>... headers) { return withHeaders(Arrays.asList(headers)); } public NettyHttpClientRequestBuilder withPipelineAdjuster(Consumer<ChannelPipeline> pipelineAdjuster) { this.pipelineAdjuster = pipelineAdjuster; return this; } public FullHttpRequest build() { ByteBuf content; if (payload != null) content = Unpooled.copiedBuffer(payload, CharsetUtil.UTF_8); else content = Unpooled.buffer(0); DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, method, uri, content); if (headers != null) request.headers().set(headers); return request; } public NettyHttpClientResponse execute(int port, long incompleteCallTimeoutMillis) throws Exception { return executeRequest(build(), port, incompleteCallTimeoutMillis, pipelineAdjuster); } public NettyHttpClientResponse execute(Channel ch, long incompleteCallTimeoutMillis) throws InterruptedException, ExecutionException, TimeoutException { return executeNettyHttpClientCall(ch, build(), incompleteCallTimeoutMillis, pipelineAdjuster); } } public static void verifyErrorReceived(String response, int responseStatusCode, ApiError expectedApiError) throws IOException { assertThat(responseStatusCode).isEqualTo(expectedApiError.getHttpStatusCode()); DefaultErrorContractDTO responseAsError = objectMapper.readValue(response, DefaultErrorContractDTO.class); assertThat(responseAsError.errors).hasSize(1); assertThat(responseAsError.errors.get(0).code).isEqualTo(expectedApiError.getErrorCode()); assertThat(responseAsError.errors.get(0).message).isEqualTo(expectedApiError.getMessage()); assertThat(responseAsError.errors.get(0).metadata).isEqualTo(expectedApiError.getMetadata()); } public static String extractBodyFromRawRequestOrResponse(String rawRequestOrResponse) { return substringAfter(rawRequestOrResponse, "\r\n\r\n"); //body start after \r\n\r\n combo } public static String extractFullBodyFromChunks(String chunkedBody) { if (!chunkedBody.contains("\r\n")) { return chunkedBody; } // https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.6.1 String[] chunksWithSizes = chunkedBody.split("\r\n"); boolean nextChunkIsChunkSize = true; StringBuilder finalResultMinusChunkMetadata = new StringBuilder(); for (String chunk : chunksWithSizes) { if (!nextChunkIsChunkSize) { // This is not metadata - it is actual body payload. finalResultMinusChunkMetadata.append(chunk); } // Toggle our "next is metadata" flag, as according to the RFC it should alternate between // chunk-size and chunk-data. nextChunkIsChunkSize = !nextChunkIsChunkSize; } return finalResultMinusChunkMetadata.toString(); } public static HttpHeaders extractHeadersFromRawRequestOrResponse(String rawRequestOrResponseString) { int indexOfFirstCrlf = rawRequestOrResponseString.indexOf("\r\n"); int indexOfBodySeparator = rawRequestOrResponseString.indexOf("\r\n\r\n"); if (indexOfFirstCrlf == -1 || indexOfBodySeparator == -1) { throw new IllegalArgumentException("The given rawRequestOrResponseString does not appear to be a valid HTTP message"); } String concatHeaders = rawRequestOrResponseString.substring(indexOfFirstCrlf + "\r\n".length(), indexOfBodySeparator); HttpHeaders extractedHeaders = new DefaultHttpHeaders(); for (String concatHeader : split(concatHeaders, "\r\n")) { extractedHeaders.add(substringBefore(concatHeader, HEADER_SEPARATOR).trim(), substringAfter(concatHeader, HEADER_SEPARATOR).trim()); } return extractedHeaders; } public static Map<String, List<String>> headersToMap(HttpHeaders headers) { Map<String, List<String>> result = new LinkedHashMap<>(); headers.names().forEach(headerKey -> result.put(headerKey, headers.getAll(headerKey))); return result; } }