package com.nike.riposte.server.componenttest;

import com.nike.riposte.server.Server;
import com.nike.riposte.server.config.ServerConfig;
import com.nike.riposte.server.http.Endpoint;
import com.nike.riposte.server.http.RequestInfo;
import com.nike.riposte.server.http.ResponseInfo;
import com.nike.riposte.server.http.StandardEndpoint;
import com.nike.riposte.server.testutils.ComponentTestUtils;
import com.nike.riposte.server.testutils.ComponentTestUtils.NettyHttpClientRequestBuilder;
import com.nike.riposte.server.testutils.ComponentTestUtils.NettyHttpClientResponse;
import com.nike.riposte.util.Matcher;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.restassured.response.ExtractableResponse;

import static com.nike.riposte.server.componenttest.VerifyRequestSizeValidationComponentTest.RequestSizeValidationConfig.GLOBAL_MAX_REQUEST_SIZE;
import static com.nike.riposte.server.testutils.ComponentTestUtils.request;
import static io.netty.handler.codec.http.HttpHeaders.Values.CHUNKED;
import static io.restassured.RestAssured.given;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Component tests verifying that the server enforces maximum request payload sizes.
 *
 * <p>Three scenarios are covered:
 * <ul>
 *     <li>the global limit returned by {@link ServerConfig#maxRequestSizeInBytes()};</li>
 *     <li>a per-endpoint limit returned by {@code maxRequestSizeInBytesOverride()};</li>
 *     <li>an endpoint whose override is {@code 0}, which these tests expect to disable the size check.</li>
 * </ul>
 *
 * <p>Each scenario is exercised twice:
 * <ul>
 *     <li>with a {@code Content-Length} request sent via RestAssured;</li>
 *     <li>with a {@code Transfer-Encoding: chunked} request sent via the Netty test client.</li>
 * </ul>
 */
public class VerifyRequestSizeValidationComponentTest {

    private static final String BASE_URI = "http://127.0.0.1";

    /** Timeout passed to the Netty test client for chunked requests. */
    private static final int INCOMPLETE_CALL_TIMEOUT_MILLIS = 2000;

    /** Stateless JSON parser for error-response bodies; safe to share across tests. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private static Server server;
    private static ServerConfig serverConfig;

    @BeforeClass
    public static void setUpClass() throws Exception {
        serverConfig = new RequestSizeValidationConfig();
        server = new Server(serverConfig);
        server.startup();
    }

    @AfterClass
    public static void tearDown() throws Exception {
        // Guard against a failed setUpClass so the original startup error is not masked by an NPE here.
        if (server != null) {
            server.shutdown();
        }
    }

    @Test
    public void should_return_bad_request_when_ContentLength_header_exceeds_global_configured_max_request_size() throws IOException {
        ExtractableResponse<?> response =
                sendContentLengthRequest(BasicEndpoint.MATCHING_PATH, GLOBAL_MAX_REQUEST_SIZE + 1);

        assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.BAD_REQUEST.code());
        assertBadRequestErrorMessageAndMetadata(response.asString());
    }

    @Test
    public void should_return_expected_response_when_ContentLength_header_not_exceeding_global_request_size() {
        ExtractableResponse<?> response =
                sendContentLengthRequest(BasicEndpoint.MATCHING_PATH, GLOBAL_MAX_REQUEST_SIZE);

        assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code());
        assertThat(response.asString()).isEqualTo(BasicEndpoint.RESPONSE_PAYLOAD);
    }

    @Test
    public void should_return_bad_request_when_ContentLength_header_exceeds_endpoint_overridden_configured_max_request_size() throws IOException {
        ExtractableResponse<?> response = sendContentLengthRequest(
                BasicEndpointWithRequestSizeValidationOverride.MATCHING_PATH,
                BasicEndpointWithRequestSizeValidationOverride.MAX_REQUEST_SIZE + 1);

        assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.BAD_REQUEST.code());
        assertBadRequestErrorMessageAndMetadata(response.asString());
    }

    @Test
    public void should_return_expected_response_when_ContentLength_header_not_exceeding_endpoint_overridden_request_size() {
        ExtractableResponse<?> response = sendContentLengthRequest(
                BasicEndpointWithRequestSizeValidationOverride.MATCHING_PATH,
                BasicEndpointWithRequestSizeValidationOverride.MAX_REQUEST_SIZE);

        assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code());
        assertThat(response.asString()).isEqualTo(BasicEndpointWithRequestSizeValidationOverride.RESPONSE_PAYLOAD);
    }

    @Test
    public void should_return_expected_response_when_endpoint_disabled_ContentLength_header_above_global_size_validation() {
        ExtractableResponse<?> response = sendContentLengthRequest(
                BasicEndpointWithRequestSizeValidationDisabled.MATCHING_PATH,
                GLOBAL_MAX_REQUEST_SIZE + 100);

        assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code());
        assertThat(response.asString()).isEqualTo(BasicEndpointWithRequestSizeValidationDisabled.RESPONSE_PAYLOAD);
    }

    @Test
    public void should_return_bad_request_when_chunked_request_exceeds_global_configured_max_request_size() throws Exception {
        // when
        NettyHttpClientResponse serverResponse =
            sendChunkedRequest(BasicEndpoint.MATCHING_PATH, GLOBAL_MAX_REQUEST_SIZE + 1);

        // then
        assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.BAD_REQUEST.code());
        assertBadRequestErrorMessageAndMetadata(serverResponse.payload);
    }

    @Test
    public void should_return_expected_response_when_chunked_request_not_exceeding_global_request_size() throws Exception {
        // when
        NettyHttpClientResponse serverResponse =
            sendChunkedRequest(BasicEndpoint.MATCHING_PATH, GLOBAL_MAX_REQUEST_SIZE);

        // then
        assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.OK.code());
        assertThat(serverResponse.payload).isEqualTo(BasicEndpoint.RESPONSE_PAYLOAD);
    }

    @Test
    public void should_return_bad_request_when_chunked_request_exceeds_endpoint_overridden_configured_max_request_size() throws Exception {
        // when
        NettyHttpClientResponse serverResponse = sendChunkedRequest(
            BasicEndpointWithRequestSizeValidationOverride.MATCHING_PATH,
            BasicEndpointWithRequestSizeValidationOverride.MAX_REQUEST_SIZE + 1);

        // then
        assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.BAD_REQUEST.code());
        assertBadRequestErrorMessageAndMetadata(serverResponse.payload);
    }

    @Test
    public void should_return_expected_response_when_chunked_request_not_exceeding_endpoint_overridden_request_size() throws Exception {
        // when
        NettyHttpClientResponse serverResponse = sendChunkedRequest(
            BasicEndpointWithRequestSizeValidationOverride.MATCHING_PATH,
            BasicEndpointWithRequestSizeValidationOverride.MAX_REQUEST_SIZE);

        // then
        assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.OK.code());
        assertThat(serverResponse.payload).isEqualTo(BasicEndpointWithRequestSizeValidationOverride.RESPONSE_PAYLOAD);
    }

    @Test
    public void should_return_expected_response_when_endpoint_disabled_chunked_request_size_validation() throws Exception {
        // when
        NettyHttpClientResponse serverResponse = sendChunkedRequest(
            BasicEndpointWithRequestSizeValidationDisabled.MATCHING_PATH,
            GLOBAL_MAX_REQUEST_SIZE + 100);

        // then
        assertThat(serverResponse.statusCode).isEqualTo(HttpResponseStatus.OK.code());
        assertThat(serverResponse.payload).isEqualTo(BasicEndpointWithRequestSizeValidationDisabled.RESPONSE_PAYLOAD);
    }

    /**
     * POSTs a generated payload of the given size to {@code path} using RestAssured.
     * RestAssured sends the body with a {@code Content-Length} header.
     *
     * @param path               the endpoint path to call
     * @param payloadSizeInBytes number of single-byte ASCII digits to send as the body
     * @return the extracted response
     */
    private static ExtractableResponse<?> sendContentLengthRequest(String path, int payloadSizeInBytes) {
        return given()
                .baseUri(BASE_URI)
                .port(serverConfig.endpointsPort())
                .basePath(path)
                .log().all()
                .body(generatePayloadOfSizeInBytes(payloadSizeInBytes))
                .when()
                .post()
                .then()
                .log().headers()
                .extract();
    }

    /**
     * POSTs a generated payload of the given size to {@code path} using the Netty test client.
     * The request carries a {@code Transfer-Encoding: chunked} header.
     *
     * @param path               the endpoint path to call
     * @param payloadSizeInBytes number of single-byte ASCII digits to send as the body
     * @return the server's response as captured by the test client
     * @throws Exception if the test client fails to execute the call
     */
    private static NettyHttpClientResponse sendChunkedRequest(String path, int payloadSizeInBytes) throws Exception {
        NettyHttpClientRequestBuilder request = request()
            .withMethod(HttpMethod.POST)
            .withUri(path)
            .withPaylod(generatePayloadOfSizeInBytes(payloadSizeInBytes))
            .withHeader(HttpHeaders.Names.TRANSFER_ENCODING, CHUNKED);

        return request.execute(serverConfig.endpointsPort(), INCOMPLETE_CALL_TIMEOUT_MILLIS);
    }

    /**
     * Asserts that {@code response} is a JSON error body whose first error is the expected
     * "payload too large" error. That error must have the message "Malformed request" and a
     * {@code metadata.cause} stating the maximum payload size was exceeded.
     *
     * <p>{@link JsonNode#path} is used instead of {@code get}, so a body with an unexpected shape
     * produces an assertion failure that includes the body, rather than a bare NullPointerException.
     *
     * @param response the raw response body
     * @throws IOException if the body is not valid JSON
     */
    private void assertBadRequestErrorMessageAndMetadata(String response) throws IOException {
        JsonNode error = OBJECT_MAPPER.readTree(response).path("errors").path(0);
        assertThat(error.path("message").textValue())
                .as("error message in response body: %s", response)
                .isEqualTo("Malformed request");
        assertThat(error.path("metadata").path("cause").textValue())
                .as("error cause in response body: %s", response)
                .isEqualTo("The request exceeded the maximum payload size allowed");
    }

    /**
     * Builds a string of exactly {@code length} characters cycling through the digits 0-9.
     * Every character is single-byte ASCII, so the character count equals the encoded byte size.
     */
    private static String generatePayloadOfSizeInBytes(int length) {
        StringBuilder sb = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            sb.append(i % 10);
        }
        return sb.toString();
    }

    /** POST endpoint with no size override, so the global limit applies. It returns a fixed payload. */
    private static class BasicEndpoint extends StandardEndpoint<Void, String> {

        public static final String MATCHING_PATH = "/basicEndpoint";
        public static final String RESPONSE_PAYLOAD = "basic-endpoint-" + UUID.randomUUID().toString();

        @Override
        public @NotNull CompletableFuture<ResponseInfo<String>> execute(
            @NotNull RequestInfo<Void> request,
            @NotNull Executor longRunningTaskExecutor,
            @NotNull ChannelHandlerContext ctx
        ) {
            return CompletableFuture.completedFuture(
                    ResponseInfo.newBuilder(RESPONSE_PAYLOAD).build()
            );
        }

        @Override
        public @NotNull Matcher requestMatcher() {
            return Matcher.match(MATCHING_PATH, HttpMethod.POST);
        }
    }

    /**
     * POST endpoint that overrides the global limit with its own {@link #MAX_REQUEST_SIZE}.
     * It returns a fixed payload.
     */
    private static class BasicEndpointWithRequestSizeValidationOverride extends StandardEndpoint<Void, String> {

        public static final String MATCHING_PATH = "/basicEndpointWithOverride";
        public static final String RESPONSE_PAYLOAD = "basic-endpoint-" + UUID.randomUUID().toString();
        public static final int MAX_REQUEST_SIZE = 10;

        @Override
        public @NotNull CompletableFuture<ResponseInfo<String>> execute(
            @NotNull RequestInfo<Void> request,
            @NotNull Executor longRunningTaskExecutor,
            @NotNull ChannelHandlerContext ctx
        ) {
            return CompletableFuture.completedFuture(
                    ResponseInfo.newBuilder(RESPONSE_PAYLOAD).build()
            );
        }

        @Override
        public @NotNull Matcher requestMatcher() {
            return Matcher.match(MATCHING_PATH, HttpMethod.POST);
        }

        @Override
        public @Nullable Integer maxRequestSizeInBytesOverride() {
            return MAX_REQUEST_SIZE;
        }
    }

    /**
     * POST endpoint whose size override is {@code 0}. The tests above expect this to disable
     * size validation, so oversized bodies are accepted. It returns a fixed payload.
     */
    private static class BasicEndpointWithRequestSizeValidationDisabled extends StandardEndpoint<Void, String> {

        public static final String MATCHING_PATH = "/basicEndpointWithRequestSizeValidationDisabled";
        public static final String RESPONSE_PAYLOAD = "basic-endpoint-" + UUID.randomUUID().toString();
        public static final int MAX_REQUEST_SIZE = 0;

        @Override
        public @NotNull CompletableFuture<ResponseInfo<String>> execute(
            @NotNull RequestInfo<Void> request,
            @NotNull Executor longRunningTaskExecutor,
            @NotNull ChannelHandlerContext ctx
        ) {
            return CompletableFuture.completedFuture(
                    ResponseInfo.newBuilder(RESPONSE_PAYLOAD).build()
            );
        }

        @Override
        public @NotNull Matcher requestMatcher() {
            return Matcher.match(MATCHING_PATH, HttpMethod.POST);
        }

        @Override
        public @Nullable Integer maxRequestSizeInBytesOverride() {
            return MAX_REQUEST_SIZE;
        }
    }

    /**
     * Server config for these tests:
     * <ul>
     *     <li>registers the three test endpoints;</li>
     *     <li>listens on a free port found at construction time;</li>
     *     <li>applies {@link #GLOBAL_MAX_REQUEST_SIZE} as the global request-size limit.</li>
     * </ul>
     */
    public static class RequestSizeValidationConfig implements ServerConfig {
        /** Global maximum request size in bytes; statically imported by the tests above. */
        public static final int GLOBAL_MAX_REQUEST_SIZE = 5;

        private final Collection<Endpoint<?>> endpoints = Arrays.asList(new BasicEndpoint(),
                new BasicEndpointWithRequestSizeValidationOverride(),
                new BasicEndpointWithRequestSizeValidationDisabled());

        private final int port;

        public RequestSizeValidationConfig() {
            try {
                port = ComponentTestUtils.findFreePort();
            } catch (IOException e) {
                throw new RuntimeException("Couldn't allocate port", e);
            }
        }

        @Override
        public int maxRequestSizeInBytes() {
            return GLOBAL_MAX_REQUEST_SIZE;
        }

        @Override
        public @NotNull Collection<@NotNull Endpoint<?>> appEndpoints() {
            return endpoints;
        }

        @Override
        public int endpointsPort() {
            return port;
        }
    }

}