/**
 * Copyright 2016-2017 Sixt GmbH & Co. Autovermietung KG
 * Licensed under the Apache License, Version 2.0 (the "License"); you may 
 * not use this file except in compliance with the License. You may obtain a 
 * copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software 
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 
 * License for the specific language governing permissions and limitations 
 * under the License.
 */

package com.sixt.service.framework.jetty;

import com.codahale.metrics.MetricRegistry;
import com.google.common.primitives.Ints;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.google.protobuf.Message;
import com.sixt.service.framework.*;
import com.sixt.service.framework.metrics.GoTimer;
import com.sixt.service.framework.protobuf.ProtobufUtil;
import com.sixt.service.framework.protobuf.RpcEnvelope;
import com.sixt.service.framework.rpc.RpcCallException;
import com.sixt.service.framework.util.ReflectionUtil;
import io.opentracing.Span;
import io.opentracing.Tracer;
import io.opentracing.tag.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Map;

@Singleton
/**
 * RPC handler for protobuf-encoded requests.
 * <p>
 * A request body consists of two length-prefixed chunks. Each prefix is a 4-byte integer decoded
 * with Guava's {@link Ints#fromByteArray}. The first chunk is a serialized
 * {@link RpcEnvelope.Request} naming the target method. The second chunk is the serialized
 * request message for that method; a zero length means an empty message. Responses use the same
 * framing with an {@link RpcEnvelope.Response} header.
 * <p>
 * The class keeps no per-request state in fields (only a static logger), so the singleton
 * instance does not share request data between calls.
 */
public class ProtobufHandler extends RpcHandler {

    private static final Logger logger = LoggerFactory.getLogger(ProtobufHandler.class);

    /**
     * Creates the handler. All dependencies are passed straight through to {@link RpcHandler}.
     */
    @Inject
    public ProtobufHandler(MethodHandlerDictionary handlers, MetricRegistry registry,
                           RpcHandlerMetrics handlerMetrics, ServiceProperties serviceProperties, Tracer tracer) {
        super(handlers, registry, handlerMetrics, serviceProperties, tracer);
    }

    /**
     * Decodes one protobuf RPC request, dispatches it to the registered method handler and writes
     * the framed response.
     * <p>
     * Failures are mapped to error responses as follows:
     * <ul>
     *   <li>{@link RpcCallException}: uses its category's HTTP status.</li>
     *   <li>{@link RpcReadException} (undecodable framing): uses 400.</li>
     *   <li>Anything else: uses 500.</li>
     * </ul>
     * Every failure marks the tracing span (if one was created) as errored and increments the
     * failure counter exactly once.
     *
     * @param req  incoming HTTP request whose body carries the framed RPC message
     * @param resp HTTP response the framed RPC reply is written to
     */
    @SuppressWarnings("unchecked")
    public void doPost(HttpServletRequest req, HttpServletResponse resp) {
        logger.debug("Handling protobuf request");

        RpcEnvelope.Request rpcRequest = null;
        String methodName = null;
        Span span = null;
        Map<String, String> headers = gatherHttpHeaders(req);
        OrangeContext context = new OrangeContext(headers);

        try {
            MDC.put(OrangeContext.CORRELATION_ID, context.getCorrelationId());

            ServletInputStream in = req.getInputStream();
            rpcRequest = readRpcEnvelope(in);

            methodName = rpcRequest.getServiceMethod();

            span = getSpan(methodName, headers, context);

            ServiceMethodHandler handler = handlers.getMethodHandler(methodName);
            if (handler == null) {
                // Do not increment the failure counter here: the catch (Exception) block below
                // already does so, and counting here too would record this failure twice.
                throw new IllegalArgumentException("Invalid method: " + methodName);
            }

            Class<? extends Message> requestClass = (Class<? extends Message>)
                    ReflectionUtil.findSubClassParameterType(handler, 0);

            Message pbRequest = readRpcBody(in, requestClass);

            GoTimer methodTimer = getMethodTimer(methodName, context.getRpcOriginService(),
                    context.getRpcOriginMethod());
            long startTime = methodTimer.start();

            Message pbResponse = invokeHandlerChain(methodName, handler, pbRequest, context);

            resp.setContentType(RpcServlet.TYPE_OCTET);
            sendSuccessfulResponse(resp, rpcRequest, pbResponse);

            //TODO: should we check the response for errors?
            methodTimer.recordSuccess(startTime);
            incrementSuccessCounter(methodName, context.getRpcOriginService(),
                    context.getRpcOriginMethod());
        } catch (RpcCallException rpcEx) {
            sendErrorResponse(resp, rpcRequest, rpcEx.toString(), rpcEx.getCategory().getHttpStatus());
            markRequestFailed(span, methodName, context);
        } catch (RpcReadException ex) {
            logger.warn("Bad request, cannot decode rpc message: {}", ex.toJson(req));
            sendErrorResponse(resp, rpcRequest, ex.getMessage(), HttpServletResponse.SC_BAD_REQUEST);
            markRequestFailed(span, methodName, context);
        } catch (Exception ex) {
            logger.warn("Uncaught exception", ex);
            sendErrorResponse(resp, rpcRequest, ex.getMessage(), HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            markRequestFailed(span, methodName, context);
        } finally {
            if (span != null) {
                span.finish();
            }
            MDC.remove(OrangeContext.CORRELATION_ID);
        }
    }

    /**
     * Records a failed request: tags the span (if any) as errored and increments the failure
     * counter for the method and the calling service/method taken from the context.
     *
     * @param span       tracing span for the request, or {@code null} if none was created yet
     * @param methodName target method, or {@code null} if the envelope could not be decoded
     * @param context    request context supplying the origin service/method
     */
    private void markRequestFailed(Span span, String methodName, OrangeContext context) {
        if (span != null) {
            Tags.ERROR.set(span, true);
        }
        incrementFailureCounter(methodName, context.getRpcOriginService(),
                context.getRpcOriginMethod());
    }

    /**
     * Writes a 200 response containing an {@link RpcEnvelope.Response} header chunk that echoes
     * the request's method and sequence number, followed by the serialized response message.
     * A {@code null} response message is sent as an empty body chunk.
     * <p>
     * IOExceptions while writing are logged at debug level and swallowed, because they are
     * assumed to mean the client disconnected.
     */
    private void sendSuccessfulResponse(HttpServletResponse response,
                                        RpcEnvelope.Request rpcRequest,
                                        Message pbResponse) throws IOException {
        response.setStatus(HttpServletResponse.SC_OK);

        RpcEnvelope.Response rpcResponse = RpcEnvelope.Response.newBuilder()
                .setServiceMethod(rpcRequest.getServiceMethod())
                .setSequenceNumber(rpcRequest.getSequenceNumber())
                .build();
        byte[] responseHeader = rpcResponse.toByteArray();
        byte[] responseBody = (pbResponse == null) ? new byte[0] : pbResponse.toByteArray();

        try {
            ServletOutputStream out = response.getOutputStream();
            writeChunk(out, responseHeader);
            writeChunk(out, responseBody);
        } catch (IOException ioex) {
            //there is nothing we can do, client probably went away
            logger.debug("Caught IOException, assuming client disconnected", ioex);
        }
    }

    /**
     * Writes an RPC error frame.
     * <p>
     * The frame is a header chunk carrying {@code message} (or the literal {@code "null"}),
     * followed by a zero-length body chunk. The HTTP status is {@code httpStatusCode} when
     * {@link FeatureFlags#shouldExposeErrorsToHttp} allows it, otherwise 200. If the request
     * envelope was never decoded, nothing is written: there is no method name or sequence number
     * to echo back. Failures while writing are logged and swallowed.
     */
    private void sendErrorResponse(HttpServletResponse resp,
                                   RpcEnvelope.Request rpcRequest,
                                   String message,
                                   int httpStatusCode) {
        if (rpcRequest == null) {
            return;
        }
        try {
            int status = FeatureFlags.shouldExposeErrorsToHttp(serviceProps)
                    ? httpStatusCode
                    : HttpServletResponse.SC_OK;
            resp.setStatus(status);

            RpcEnvelope.Response rpcResponse = RpcEnvelope.Response.newBuilder()
                    .setServiceMethod(rpcRequest.getServiceMethod())
                    .setSequenceNumber(rpcRequest.getSequenceNumber())
                    .setError(message == null ? "null" : message)
                    .build();
            ServletOutputStream out = resp.getOutputStream();
            writeChunk(out, rpcResponse.toByteArray());
            out.write(Ints.toByteArray(0)); //zero-length (no) body
        } catch (Exception ex) {
            logger.warn("Error writing error response", ex);
        }
    }

    /** Writes {@code chunk} preceded by its 4-byte length prefix. */
    private static void writeChunk(ServletOutputStream out, byte[] chunk) throws IOException {
        out.write(Ints.toByteArray(chunk.length));
        out.write(chunk);
    }

    /**
     * Reads and parses the envelope (header) chunk.
     *
     * @throws RpcReadException if the declared size is not in {@code 1..MAX_HEADER_CHUNK_SIZE}
     */
    private RpcEnvelope.Request readRpcEnvelope(ServletInputStream in) throws Exception {
        byte[] chunkSize = readChunkSize(in);
        int size = Ints.fromByteArray(chunkSize);
        if (size <= 0 || size > ProtobufUtil.MAX_HEADER_CHUNK_SIZE) {
            throw new RpcReadException(chunkSize, in, "Invalid header chunk size: " + size);
        }
        return RpcEnvelope.Request.parseFrom(readFully(in, size));
    }

    /**
     * Reads the body chunk and parses it as {@code requestClass}. A zero size yields an empty
     * message of that type.
     *
     * @throws RpcReadException if the declared size is negative or exceeds MAX_BODY_CHUNK_SIZE
     */
    private Message readRpcBody(ServletInputStream in,
                                Class<? extends Message> requestClass) throws Exception {
        byte[] chunkSize = readChunkSize(in);
        int size = Ints.fromByteArray(chunkSize);
        if (size == 0) {
            return ProtobufUtil.newEmptyMessage(requestClass);
        }
        // A negative size would otherwise fail later with NegativeArraySizeException (a 500);
        // treat it as the malformed input it is.
        if (size < 0 || size > ProtobufUtil.MAX_BODY_CHUNK_SIZE) {
            throw new RpcReadException(chunkSize, in, "Invalid body chunk size: " + size);
        }
        return ProtobufUtil.byteArrayToProtobuf(readFully(in, size), requestClass);
    }

    /**
     * Reads a 4-byte chunk length prefix.
     * <p>
     * A single {@code read} call may return fewer bytes than requested, so this keeps reading
     * until the prefix is complete or the stream ends. Bytes not supplied before end-of-stream
     * are left as zero.
     */
    private static byte[] readChunkSize(ServletInputStream in) throws IOException {
        byte[] prefix = new byte[Ints.BYTES];
        int filled = 0;
        while (filled < prefix.length) {
            int read = in.read(prefix, filled, prefix.length - filled);
            if (read == -1) {
                break;
            }
            filled += read;
        }
        return prefix;
    }

    /**
     * Reads exactly {@code totalSize} bytes.
     *
     * @throws RpcCallException (InternalServerError) on premature end-of-stream or an IOException
     */
    private static byte[] readFully(ServletInputStream in, int totalSize) throws RpcCallException {
        byte[] retval = new byte[totalSize];
        int bytesRead = 0;
        while (bytesRead < totalSize) {
            try {
                int read = in.read(retval, bytesRead, totalSize - bytesRead);
                if (read == -1) {
                    throw new RpcCallException(RpcCallException.Category.InternalServerError,
                            "Unable to read complete request or response");
                }
                bytesRead += read;
            } catch (IOException e) {
                throw new RpcCallException(RpcCallException.Category.InternalServerError,
                        "IOException reading data: " + e);
            }
        }
        return retval;
    }

}