package org.pytorch.serve.http;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.util.CharsetUtil;
import java.io.IOException;
import java.nio.file.FileAlreadyExistsException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.Manifest;
import org.pytorch.serve.archive.ModelArchive;
import org.pytorch.serve.archive.ModelException;
import org.pytorch.serve.archive.ModelNotFoundException;
import org.pytorch.serve.archive.ModelVersionNotFoundException;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.snapshot.SnapshotManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerThread;

/**
 * A class handling inbound HTTP requests to the management API.
 *
 * <p>This handler serves the {@code models} resource (list, register, describe, scale,
 * unregister and set-default-version operations) and dispatches to any custom endpoint registered
 * in the endpoint map under the requested resource segment. Requests it does not claim are
 * forwarded to the next handler in the chain.
 */
public class ManagementRequestHandler extends HttpRequestHandlerChain {

    /** Upper bound, and default value, of the {@code limit} parameter of the list-models API. */
    private static final int MAX_LIST_LIMIT = 100;

    /**
     * Status sent by a synchronous scale request whose update future completed with OK while the
     * model's scale request status is not yet complete.
     */
    private static final HttpResponseStatus PARTIAL_SUCCESS =
            new HttpResponseStatus(210, "Partial Success");

    /**
     * Creates a new {@code ManagementRequestHandler} instance.
     *
     * @param ep custom endpoints, keyed by the resource path segment ({@code segments[1]}) they
     *     serve
     */
    public ManagementRequestHandler(Map<String, ModelServerEndpoint> ep) {
        endpointMap = ep;
    }

    /**
     * Routes a management request.
     *
     * <ul>
     *   <li>{@code GET /models} lists models; {@code POST /models} registers one.
     *   <li>{@code GET|PUT|DELETE /models/{name}[/{version}]} describes, scales or unregisters.
     *   <li>{@code PUT /models/{name}/{version}/set-default} changes the default version.
     * </ul>
     */
    @Override
    protected void handleRequest(
            ChannelHandlerContext ctx,
            FullHttpRequest req,
            QueryStringDecoder decoder,
            String[] segments)
            throws ModelException {
        if (!isManagementReq(segments)) {
            chain.handleRequest(ctx, req, decoder, segments);
            return;
        }
        if (segments.length < 2) {
            // An empty segment array is claimed by isManagementReq but names no resource; answer
            // 404 rather than failing with an ArrayIndexOutOfBoundsException on segments[1].
            throw new ResourceNotFoundException();
        }
        if (endpointMap.get(segments[1]) != null) {
            handleCustomEndpoint(ctx, req, segments, decoder);
            return;
        }
        if (!"models".equals(segments[1])) {
            throw new ResourceNotFoundException();
        }

        HttpMethod method = req.method();
        if (segments.length < 3) {
            if (HttpMethod.GET.equals(method)) {
                handleListModels(ctx, decoder);
                return;
            } else if (HttpMethod.POST.equals(method)) {
                handleRegisterModel(ctx, decoder, req);
                return;
            }
            throw new MethodNotAllowedException();
        }

        if (segments.length > 4) {
            // The only five-segment route is ".../set-default", and it only supports PUT. Other
            // methods must not fall through to describe/unregister with a null version.
            if (segments.length != 5 || !"set-default".equals(segments[4])) {
                throw new ResourceNotFoundException();
            }
            if (!HttpMethod.PUT.equals(method)) {
                throw new MethodNotAllowedException();
            }
            setDefaultModelVersion(ctx, segments[2], segments[3]);
            return;
        }

        String modelName = segments[2];
        String modelVersion = segments.length == 4 ? segments[3] : null;
        if (HttpMethod.GET.equals(method)) {
            handleDescribeModel(ctx, modelName, modelVersion);
        } else if (HttpMethod.PUT.equals(method)) {
            handleScaleModel(ctx, decoder, modelName, modelVersion);
        } else if (HttpMethod.DELETE.equals(method)) {
            handleUnregisterModel(ctx, modelName, modelVersion);
        } else {
            throw new MethodNotAllowedException();
        }
    }

    /**
     * Decides whether this handler claims the request path.
     *
     * @param segments the request path split into segments
     * @return true for {@code models} paths of 2-4 segments, five-segment {@code set-default}
     *     paths, registered custom endpoints, and an empty segment array
     */
    private boolean isManagementReq(String[] segments) {
        if (segments.length < 2) {
            // Without a resource segment there is nothing to inspect; indexing segments[1] here
            // would throw. An empty array stays claimed, as before.
            return segments.length == 0;
        }
        String resource = segments[1];
        return (segments.length <= 4 && "models".equals(resource))
                || (segments.length == 5 && "set-default".equals(segments[4]))
                || endpointMap.containsKey(resource);
    }

    /**
     * Sends one page of default model versions, sorted by model name.
     *
     * <p>{@code limit} is clamped to {@code [0, 100]} (out-of-range values become 100) and a
     * negative {@code next_page_token} is treated as 0. A next-page token is included only when
     * more models remain after the returned page.
     */
    private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
        int limit = NettyUtils.getIntParameter(decoder, "limit", MAX_LIST_LIMIT);
        int pageToken = NettyUtils.getIntParameter(decoder, "next_page_token", 0);
        if (limit > MAX_LIST_LIMIT || limit < 0) {
            limit = MAX_LIST_LIMIT;
        }
        if (pageToken < 0) {
            pageToken = 0;
        }

        ModelManager modelManager = ModelManager.getInstance();
        Map<String, Model> models = modelManager.getDefaultModels();

        List<String> keys = new ArrayList<>(models.keySet());
        Collections.sort(keys);
        ListModelsResponse list = new ListModelsResponse();

        int start = Math.min(pageToken, keys.size());
        // Add in long so a very large next_page_token cannot overflow into a negative index.
        int last = (int) Math.min((long) start + limit, keys.size());
        if (last < keys.size()) {
            // Advertise another page only when models remain beyond this one.
            list.setNextPageToken(String.valueOf(last));
        }

        for (int i = start; i < last; ++i) {
            String modelName = keys.get(i);
            Model model = models.get(modelName);
            list.addModel(modelName, model.getModelUrl());
        }

        NettyUtils.sendJsonResponse(ctx, list);
    }

    /**
     * Sends a JSON array describing one model version, or every version when {@code modelVersion}
     * is {@code "all"}.
     *
     * @throws ModelNotFoundException if the requested model/version cannot be found
     */
    private void handleDescribeModel(
            ChannelHandlerContext ctx, String modelName, String modelVersion)
            throws ModelNotFoundException, ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        List<DescribeModelResponse> resp = new ArrayList<>();

        if ("all".equals(modelVersion)) {
            for (Map.Entry<String, Model> m : modelManager.getAllModelVersions(modelName)) {
                resp.add(createModelResponse(modelManager, modelName, m.getValue()));
            }
        } else {
            Model model = modelManager.getModel(modelName, modelVersion);
            if (model == null) {
                throw new ModelNotFoundException("Model not found: " + modelName);
            }
            resp.add(createModelResponse(modelManager, modelName, model));
        }

        NettyUtils.sendJsonResponse(ctx, resp);
    }

    /** Builds the describe-model payload for a single model version, including its workers. */
    private DescribeModelResponse createModelResponse(
            ModelManager modelManager, String modelName, Model model) {
        DescribeModelResponse resp = new DescribeModelResponse();
        resp.setModelName(modelName);
        resp.setModelUrl(model.getModelUrl());
        resp.setBatchSize(model.getBatchSize());
        resp.setMaxBatchDelay(model.getMaxBatchDelay());
        resp.setMaxWorkers(model.getMaxWorkers());
        resp.setMinWorkers(model.getMinWorkers());
        resp.setLoadedAtStartup(modelManager.getStartupModels().contains(modelName));
        Manifest manifest = model.getModelArchive().getManifest();
        Manifest.Engine engine = manifest.getEngine();
        if (engine != null) {
            resp.setEngine(engine.getEngineName());
        }
        resp.setModelVersion(manifest.getModel().getModelVersion());
        resp.setRuntime(manifest.getRuntime().getValue());

        List<WorkerThread> workers = modelManager.getWorkers(model.getModelVersionName());
        for (WorkerThread worker : workers) {
            String workerId = worker.getWorkerId();
            long startTime = worker.getStartTime();
            boolean isRunning = worker.isRunning();
            int gpuId = worker.getGpuId();
            long memory = worker.getMemory();
            resp.addWorker(workerId, startTime, isRunning, gpuId, memory);
        }

        return resp;
    }

    /**
     * Registers a model archive from the request (JSON body or query parameters) and, when
     * {@code initial_workers} is positive, starts that many workers.
     *
     * <p>If worker startup fails on the synchronous path, the freshly registered version is
     * unregistered again via the {@code onError} callback.
     *
     * @throws BadRequestException if the url parameter is missing or the runtime is unknown
     * @throws InternalServerException if the model archive cannot be saved
     */
    private void handleRegisterModel(
            ChannelHandlerContext ctx, QueryStringDecoder decoder, FullHttpRequest req)
            throws ModelException {
        RegisterModelRequest registerModelRequest = parseRequest(req, decoder);
        String modelUrl = registerModelRequest.getModelUrl();
        if (modelUrl == null) {
            throw new BadRequestException("Parameter url is required.");
        }

        String modelName = registerModelRequest.getModelName();
        String runtime = registerModelRequest.getRuntime();
        String handler = registerModelRequest.getHandler();
        int batchSize = registerModelRequest.getBatchSize();
        int maxBatchDelay = registerModelRequest.getMaxBatchDelay();
        int initialWorkers = registerModelRequest.getInitialWorkers();
        boolean synchronous = registerModelRequest.getSynchronous();
        int responseTimeout = registerModelRequest.getResponseTimeout();
        if (responseTimeout == -1) {
            // -1 is the "not supplied" sentinel; fall back to the server-wide default.
            responseTimeout = ConfigManager.getInstance().getDefaultResponseTimeout();
        }
        Manifest.RuntimeType runtimeType = null;
        if (runtime != null) {
            try {
                runtimeType = Manifest.RuntimeType.fromValue(runtime);
            } catch (IllegalArgumentException e) {
                throw new BadRequestException(e);
            }
        }

        ModelManager modelManager = ModelManager.getInstance();
        final ModelArchive archive;
        try {
            archive =
                    modelManager.registerModel(
                            modelUrl,
                            modelName,
                            runtimeType,
                            handler,
                            batchSize,
                            maxBatchDelay,
                            responseTimeout,
                            null);
        } catch (FileAlreadyExistsException e) {
            throw new InternalServerException(
                    "Model file already exists " + FilenameUtils.getName(modelUrl), e);
        } catch (IOException e) {
            throw new InternalServerException("Failed to save model: " + modelUrl, e);
        }

        modelName = archive.getModelName();

        if (initialWorkers <= 0) {
            final String msg =
                    "Model \""
                            + modelName
                            + "\" Version: "
                            + archive.getModelVersion()
                            + " registered with 0 initial workers. Use scale workers API to add workers for the model.";
            SnapshotManager.getInstance().saveSnapshot();
            NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
            return;
        }

        updateModelWorkers(
                ctx,
                modelName,
                archive.getModelVersion(),
                initialWorkers,
                initialWorkers,
                synchronous,
                true,
                f -> {
                    modelManager.unregisterModel(archive.getModelName(), archive.getModelVersion());
                    return null;
                });
    }

    /**
     * Unregisters a model version and maps the manager's status code to the matching exception.
     *
     * @throws ModelNotFoundException if the model does not exist (NOT_FOUND)
     * @throws ModelVersionNotFoundException if the version does not exist (BAD_REQUEST)
     * @throws InternalServerException if cleanup was interrupted (INTERNAL_SERVER_ERROR)
     * @throws RequestTimeoutException if cleanup timed out (REQUEST_TIMEOUT)
     */
    private void handleUnregisterModel(
            ChannelHandlerContext ctx, String modelName, String modelVersion)
            throws ModelNotFoundException, InternalServerException, RequestTimeoutException,
                    ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        HttpResponseStatus httpResponseStatus =
                modelManager.unregisterModel(modelName, modelVersion);
        if (httpResponseStatus == HttpResponseStatus.NOT_FOUND) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        } else if (httpResponseStatus == HttpResponseStatus.BAD_REQUEST) {
            throw new ModelVersionNotFoundException(
                    String.format(
                            "Model version: %s does not exist for model: %s",
                            modelVersion, modelName));
        } else if (httpResponseStatus == HttpResponseStatus.INTERNAL_SERVER_ERROR) {
            throw new InternalServerException("Interrupted while cleaning resources: " + modelName);
        } else if (httpResponseStatus == HttpResponseStatus.REQUEST_TIMEOUT) {
            throw new RequestTimeoutException("Timed out while cleaning resources: " + modelName);
        } else if (httpResponseStatus == HttpResponseStatus.FORBIDDEN) {
            // FORBIDDEN signals an attempt to remove the model's default version.
            throw new InvalidModelVersionException(
                    "Cannot remove default version for model " + modelName);
        }
        String msg = "Model \"" + modelName + "\" unregistered";
        NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
    }

    /**
     * Handles a scale-workers request using the {@code min_worker}, {@code max_worker},
     * {@code model_version} and {@code synchronous} query parameters.
     *
     * @throws BadRequestException if {@code min_worker} is negative or exceeds {@code max_worker}
     * @throws ModelNotFoundException if no default version is registered under {@code modelName}
     */
    private void handleScaleModel(
            ChannelHandlerContext ctx,
            QueryStringDecoder decoder,
            String modelName,
            String modelVersion)
            throws ModelNotFoundException, ModelVersionNotFoundException {
        int minWorkers = NettyUtils.getIntParameter(decoder, "min_worker", 1);
        int maxWorkers = NettyUtils.getIntParameter(decoder, "max_worker", minWorkers);
        if (modelVersion == null) {
            modelVersion = NettyUtils.getParameter(decoder, "model_version", null);
        }
        if (minWorkers < 0) {
            // A negative worker count is never meaningful; reject it before it reaches the manager.
            throw new BadRequestException("min_worker cannot be negative.");
        }
        if (maxWorkers < minWorkers) {
            throw new BadRequestException("max_worker cannot be less than min_worker.");
        }
        boolean synchronous =
                Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", null));

        ModelManager modelManager = ModelManager.getInstance();
        if (!modelManager.getDefaultModels().containsKey(modelName)) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        updateModelWorkers(
                ctx, modelName, modelVersion, minWorkers, maxWorkers, synchronous, false, null);
    }

    /**
     * Requests a worker-count change for a model and reports the outcome to the client.
     *
     * <p>When {@code synchronous} is false, 202 ACCEPTED is sent immediately and the outcome of the
     * update is not reported. Otherwise the response is sent when the update future completes:
     * OK with a success message, 210 "Partial Success" while scaling is still in progress, or an
     * error, in which case {@code onError} (if non-null) is also invoked.
     *
     * @param isInit true when called from registration; selects the "registered" success message
     * @param onError optional cleanup callback run when the update fails; may be null
     */
    private void updateModelWorkers(
            final ChannelHandlerContext ctx,
            final String modelName,
            final String modelVersion,
            int minWorkers,
            int maxWorkers,
            boolean synchronous,
            boolean isInit,
            final Function<Void, Void> onError)
            throws ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        CompletableFuture<HttpResponseStatus> future =
                modelManager.updateModel(modelName, modelVersion, minWorkers, maxWorkers);
        if (!synchronous) {
            NettyUtils.sendJsonResponse(
                    ctx,
                    new StatusResponse("Processing worker updates..."),
                    HttpResponseStatus.ACCEPTED);
            return;
        }
        future.thenApply(
                        v -> {
                            boolean status =
                                    modelManager.scaleRequestStatus(modelName, modelVersion);
                            if (HttpResponseStatus.OK.equals(v)) {
                                if (status) {
                                    String msg =
                                            "Workers scaled to "
                                                    + minWorkers
                                                    + " for model: "
                                                    + modelName;
                                    if (modelVersion != null) {
                                        msg += ", version: " + modelVersion; // NOPMD
                                    }

                                    if (isInit) {
                                        msg =
                                                "Model \""
                                                        + modelName
                                                        + "\" Version: "
                                                        + modelVersion
                                                        + " registered with "
                                                        + minWorkers
                                                        + " initial workers";
                                    }

                                    NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg), v);
                                } else {
                                    NettyUtils.sendJsonResponse(
                                            ctx,
                                            new StatusResponse("Workers scaling in progress..."),
                                            PARTIAL_SUCCESS);
                                }
                            } else {
                                NettyUtils.sendError(
                                        ctx,
                                        v,
                                        new InternalServerException("Failed to start workers"));
                                if (onError != null) {
                                    onError.apply(null);
                                }
                            }
                            return v;
                        })
                .exceptionally(
                        (e) -> {
                            if (onError != null) {
                                onError.apply(null);
                            }
                            NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
                            return null;
                        });
    }

    /**
     * Builds a {@link RegisterModelRequest} from a JSON body when the content type is
     * {@code application/json}, otherwise from the query string.
     *
     * @throws BadRequestException if a JSON request carries an empty body
     */
    private RegisterModelRequest parseRequest(FullHttpRequest req, QueryStringDecoder decoder) {
        CharSequence mime = HttpUtil.getMimeType(req);
        if (!HttpHeaderValues.APPLICATION_JSON.contentEqualsIgnoreCase(mime)) {
            return new RegisterModelRequest(decoder);
        }
        RegisterModelRequest in =
                JsonUtils.GSON.fromJson(
                        req.content().toString(CharsetUtil.UTF_8), RegisterModelRequest.class);
        if (in == null) {
            // Gson returns null for an empty JSON document; report a client error instead of
            // letting the caller fail with a NullPointerException.
            throw new BadRequestException("Request body is empty.");
        }
        return in;
    }

    /**
     * Makes {@code newModelVersion} the default version of {@code modelName} and saves a snapshot.
     *
     * @throws ModelNotFoundException if the model does not exist (NOT_FOUND)
     * @throws ModelVersionNotFoundException if the version does not exist (FORBIDDEN)
     */
    private void setDefaultModelVersion(
            ChannelHandlerContext ctx, String modelName, String newModelVersion)
            throws ModelNotFoundException, InternalServerException, RequestTimeoutException,
                    ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        HttpResponseStatus httpResponseStatus =
                modelManager.setDefaultVersion(modelName, newModelVersion);
        if (httpResponseStatus == HttpResponseStatus.NOT_FOUND) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        } else if (httpResponseStatus == HttpResponseStatus.FORBIDDEN) {
            throw new ModelVersionNotFoundException(
                    "Model version " + newModelVersion + " does not exist for model " + modelName);
        }
        String msg =
                "Default version successfully updated for model \""
                        + modelName
                        + "\" to \""
                        + newModelVersion
                        + "\"";
        SnapshotManager.getInstance().saveSnapshot();
        NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
    }
}