package com.github.chhsiaoninety.nitmproxy.handler.protocol.http1; import com.github.chhsiaoninety.nitmproxy.Address; import com.github.chhsiaoninety.nitmproxy.ConnectionInfo; import com.github.chhsiaoninety.nitmproxy.NitmProxyConfig; import com.github.chhsiaoninety.nitmproxy.NitmProxyMaster; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.*; import org.junit.After; import org.junit.Before; import org.junit.Test; import static com.github.chhsiaoninety.nitmproxy.HttpObjectUtil.*; import static io.netty.util.ReferenceCountUtil.*; import static org.junit.Assert.*; import static org.mockito.Mockito.*; public class Http1FrontendHandlerTest { private NitmProxyMaster master; private EmbeddedChannel inboundChannel; private EmbeddedChannel outboundChannel; @Before public void setUp() throws Exception { master = mock(NitmProxyMaster.class); when(master.config()).thenReturn(new NitmProxyConfig()); when(master.handler(any(), any(), any())).thenAnswer(m -> new ChannelHandlerAdapter() { }); inboundChannel = new EmbeddedChannel(); } @After public void tearDown() { inboundChannel.finishAndReleaseAll(); if (outboundChannel != null) { outboundChannel.finishAndReleaseAll(); } } @Test public void shouldTunnelRequest() { Http1FrontendHandler handler = tunneledHandler(); inboundChannel.pipeline().addLast(handler); assertFalse(inboundChannel.writeInbound(requestBytes())); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); FullHttpRequest request = (FullHttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(HttpMethod.GET, request.method()); assertEquals(HttpVersion.HTTP_1_1, request.protocolVersion()); assertEquals("/", request.uri()); assertEquals("localhost", request.headers().get(HttpHeaderNames.HOST)); assertEquals(0, request.content().readableBytes()); release(request); } @Test public void shouldTunnelRequests() { Http1FrontendHandler handler = tunneledHandler(); inboundChannel.pipeline().addLast(handler); assertFalse(inboundChannel.writeInbound(requestBytes())); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); release(outboundChannel.outboundMessages().poll()); assertFalse(inboundChannel.writeInbound(requestBytes())); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); release(outboundChannel.outboundMessages().poll()); } @Test public void shouldHandleHttpProxyRequest() { Http1FrontendHandler handler = httpProxyHandler(true); inboundChannel.pipeline().addLast(handler); ByteBuf requestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:9000/")); assertFalse(inboundChannel.writeInbound(requestBytes)); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); HttpRequest httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); } @Test public void shouldHandleHttpProxyRequests() { Http1FrontendHandler handler = httpProxyHandler(true); inboundChannel.pipeline().addLast(handler); ByteBuf requestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:9000/")); assertFalse(inboundChannel.writeInbound(requestBytes.copy())); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); HttpRequest httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); // Second request assertFalse(inboundChannel.writeInbound(requestBytes)); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); } @Test public void shouldHandleHttpProxyCreateNewConnection() { Http1FrontendHandler handler = httpProxyHandler(true); inboundChannel.pipeline().addLast(handler); ByteBuf firstRequestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:8000/")); assertFalse(inboundChannel.writeInbound(firstRequestBytes)); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); HttpRequest httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); EmbeddedChannel firstOutboundChannel = outboundChannel; // Second request ByteBuf secondRequestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:9000/")); assertFalse(inboundChannel.writeInbound(secondRequestBytes)); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); assertNotSame(firstOutboundChannel, outboundChannel); assertFalse(firstOutboundChannel.isActive()); } @Test public void shouldClosedWhenHttpProxyDestinationNotAvailable() { Http1FrontendHandler handler = httpProxyHandler(false); inboundChannel.pipeline().addLast(handler); ByteBuf firstRequestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:8000/")); assertFalse(inboundChannel.writeInbound(firstRequestBytes)); assertFalse(inboundChannel.isActive()); } @Test public void shouldCreateNewOutboundWhenOldIsInactive() { Http1FrontendHandler handler = httpProxyHandler(true); inboundChannel.pipeline().addLast(handler); ByteBuf firstRequestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:8000/")); assertFalse(inboundChannel.writeInbound(firstRequestBytes)); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); HttpRequest httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); EmbeddedChannel firstOutboundChannel = outboundChannel; outboundChannel.close(); // Second request ByteBuf secondRequestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost:8000/")); assertFalse(inboundChannel.writeInbound(secondRequestBytes)); assertEquals(1, outboundChannel.outboundMessages().size()); assertTrue(outboundChannel.outboundMessages().peek() instanceof FullHttpRequest); httpRequest = (HttpRequest) outboundChannel.outboundMessages().poll(); assertEquals(httpRequest.method(), HttpMethod.GET); assertEquals(httpRequest.protocolVersion(), HttpVersion.HTTP_1_1); assertEquals(httpRequest.uri(), "/"); release(httpRequest); assertNotSame(firstOutboundChannel, outboundChannel); } @Test public void shouldHandleConnect() { Http1FrontendHandler handler = httpProxyHandler(true); inboundChannel.pipeline().addLast(handler); ByteBuf requestBytes = requestBytes(new DefaultHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.CONNECT, "localhost:8000")); assertFalse(inboundChannel.writeInbound(requestBytes)); assertNotNull(outboundChannel); assertTrue(outboundChannel.isActive()); assertEquals(1, inboundChannel.outboundMessages().size()); assertTrue(inboundChannel.outboundMessages().peek() instanceof ByteBuf); ByteBuf respByteBuf = (ByteBuf) inboundChannel.outboundMessages().poll(); byte[] respBytes = new byte[respByteBuf.readableBytes()]; respByteBuf.readBytes(respBytes); assertEquals("HTTP/1.1 200 OK\r\n\r\n", new String(respBytes)); respByteBuf.release(); } private Http1FrontendHandler httpProxyHandler(boolean outboundAvailable) { if (outboundAvailable) { when(master.connect(any(), any(), any())).then( invocationOnMock -> { outboundChannel = new EmbeddedChannel((ChannelHandler) invocationOnMock.getArguments()[2]); return outboundChannel.newSucceededFuture(); }); } else { when(master.connect(any(), any(), any())).then( invocationOnMock -> inboundChannel.newPromise().setFailure(new Exception())); } return new Http1FrontendHandler(master, connectionInfo()); } private Http1FrontendHandler tunneledHandler() { outboundChannel = new EmbeddedChannel(); return new Http1FrontendHandler(master, connectionInfo(), outboundChannel); } private static ConnectionInfo connectionInfo() { return new ConnectionInfo( new Address("localhost", 8080), new Address("localhost", 8080)); } }