package com.nike.riposte.server.componenttest;

import com.nike.internal.util.StringUtils;
import com.nike.riposte.server.Server;
import com.nike.riposte.server.config.ServerConfig;
import com.nike.riposte.server.http.Endpoint;
import com.nike.riposte.server.http.RequestInfo;
import com.nike.riposte.server.http.ResponseInfo;
import com.nike.riposte.server.http.StandardEndpoint;
import com.nike.riposte.server.testutils.ComponentTestUtils;
import com.nike.riposte.util.Matcher;

import com.google.common.base.Charsets;
import com.google.common.hash.HashCode;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;

import org.apache.commons.io.IOUtils;
import org.jetbrains.annotations.NotNull;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.multipart.FileUpload;
import io.netty.handler.codec.http.multipart.HttpData;
import io.netty.handler.codec.http.multipart.InterfaceHttpData;
import io.netty.util.CharsetUtil;

import static io.restassured.RestAssured.given;
import static java.util.Collections.singleton;
import static org.assertj.core.api.Assertions.assertThat;

public class VerifyMultipartRequestsWorkComponentTest {

    private static Server server;
    private static ServerConfig serverConfig;

    @BeforeClass
    public static void setUpClass() throws Exception {
        serverConfig = new MultipartTestConfig();
        server = new Server(serverConfig);
        server.startup();
    }

    @AfterClass
    public static void tearDown() throws Exception {
        server.shutdown();
    }

    @Test
    public void verify_multipart_file_works_properly() throws IOException, InterruptedException {
        String name = "someImageFile";
        String filename = "helloWorld.png";
        InputStream multipartFileInputStream = VerifyMultipartRequestsWorkComponentTest.class.getClassLoader().getResourceAsStream(filename);
        byte[] multipartFileBytes = IOUtils.toByteArray(multipartFileInputStream);

        String responseString =
                    given()
                        .baseUri("http://127.0.0.1")
                        .port(serverConfig.endpointsPort())
                        .basePath(MultipartTestEndpoint.MATCHING_PATH)
                        .log().all()
                    .when()
                        .multiPart(name, filename, multipartFileBytes, "image/png")
                        .post()
                    .then()
                        .log().all()
                        .statusCode(200)
                        .extract().asString();

        String expectedHash = getHashForMultipartPayload(name, filename, multipartFileBytes);
        assertThat(responseString).isEqualTo(expectedHash);
    }

    @Test
    public void verify_multipart_attribute_works_properly() throws IOException, InterruptedException {
        String name = "someAttribute";

        String uuidString = UUID.randomUUID().toString();
        String responseString =
                    given()
                        .baseUri("http://127.0.0.1")
                        .port(serverConfig.endpointsPort())
                        .basePath(MultipartTestEndpoint.MATCHING_PATH)
                        .log().all()
                    .when()
                        .multiPart(name, uuidString)
                        .post()
                    .then()
                        .log().all()
                        .statusCode(200)
                        .extract().asString();

        String expectedHash = getHashForMultipartPayload(name, null, uuidString.getBytes(CharsetUtil.UTF_8));
        assertThat(responseString).isEqualTo(expectedHash);
    }

    @Test
    public void verify_multipart_with_mixed_types_works_properly() throws IOException, InterruptedException {
        String imageName = "someImageFile";
        String imageFilename = "helloWorld.png";
        InputStream imageFileInputStream = VerifyMultipartRequestsWorkComponentTest.class.getClassLoader().getResourceAsStream(imageFilename);
        byte[] imageFileBytes = IOUtils.toByteArray(imageFileInputStream);

        String textName = "someTextFile";
        String textFilename = "testMultipartFile.txt";
        InputStream textFileInputStream = VerifyMultipartRequestsWorkComponentTest.class.getClassLoader().getResourceAsStream(textFilename);
        byte[] textFileBytes = IOUtils.toByteArray(textFileInputStream);

        String attributeName = "someAttribute";
        String attributeString = UUID.randomUUID().toString();

        String responseString =
                    given()
                        .baseUri("http://127.0.0.1")
                        .port(serverConfig.endpointsPort())
                        .basePath(MultipartTestEndpoint.MATCHING_PATH)
                        .log().all()
                    .when()
                        .multiPart(imageName, imageFilename, imageFileBytes, "image/png")
                        .multiPart(attributeName, attributeString)
                        .multiPart(textName, textFilename, textFileBytes)
                        .post()
                    .then()
                        .log().all()
                        .statusCode(200)
                        .extract().asString();

        String expectedImageFileHash = getHashForMultipartPayload(imageName, imageFilename, imageFileBytes);
        String expectedAttributeHash = getHashForMultipartPayload(attributeName, null, attributeString.getBytes(CharsetUtil.UTF_8));
        String expectedTextFileHash = getHashForMultipartPayload(textName, textFilename, textFileBytes);
        String expectedResponse = StringUtils.join(Arrays.asList(expectedImageFileHash, expectedAttributeHash, expectedTextFileHash), ",");
        assertThat(responseString).isEqualTo(expectedResponse);
    }

    private static final HashFunction hashFunction = Hashing.md5();
    private static String getHashForMultipartPayload(String name, String filename, byte[] payloadBytes) {
        Hasher hasher = hashFunction.newHasher()
                                    .putString(name, Charsets.UTF_8);
        if (filename != null)
            hasher = hasher.putString(filename, Charsets.UTF_8);

        hasher = hasher.putBytes(payloadBytes);

        HashCode hc = hasher.hash();
        return hc.toString();
    }

    public static class MultipartTestConfig implements ServerConfig {
        private final Collection<Endpoint<?>> endpoints = singleton(new MultipartTestEndpoint());
        private final int port;

        public MultipartTestConfig() {
            try {
                port = ComponentTestUtils.findFreePort();
            } catch (IOException e) {
                throw new RuntimeException("Couldn't allocate port", e);
            }
        }

        @Override
        public @NotNull Collection<@NotNull Endpoint<?>> appEndpoints() {
            return endpoints;
        }

        @Override
        public int endpointsPort() {
            return port;
        }
    }

    public static class MultipartTestEndpoint extends StandardEndpoint<String, String> {

        public static String MATCHING_PATH = "/multipart";

        @Override
        public @NotNull CompletableFuture<ResponseInfo<String>> execute(
            @NotNull RequestInfo<String> request,
            @NotNull Executor longRunningTaskExecutor,
            @NotNull ChannelHandlerContext ctx
        ) {
            List<String> hashesFound = new ArrayList<>();

            for (InterfaceHttpData multipartData : request.getMultipartParts()) {
                String name = multipartData.getName();
                byte[] payloadBytes;
                try {
                    payloadBytes = ((HttpData)multipartData).get();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }

                String filename = null;
                switch (multipartData.getHttpDataType()) {
                    case Attribute:
                        // Do nothing - filename stays null
                        break;
                    case FileUpload:
                        filename = ((FileUpload)multipartData).getFilename();
                        break;
                    default:
                        throw new RuntimeException("Unsupported multipart type: " + multipartData.getHttpDataType().name());
                }

                hashesFound.add(getHashForMultipartPayload(name, filename, payloadBytes));
            }

            return CompletableFuture.completedFuture(ResponseInfo.newBuilder(StringUtils.join(hashesFound, ",")).build());
        }

        @Override
        public @NotNull Matcher requestMatcher() {
            return Matcher.match(MATCHING_PATH, HttpMethod.POST);
        }

    }
}