package nl.knaw.huygens.timbuctoo.logging;

import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import org.glassfish.jersey.message.MessageUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import javax.annotation.Priority;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.container.PreMatching;
import javax.ws.rs.core.MultivaluedMap;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

@PreMatching
@Priority(Integer.MIN_VALUE)
public final class LoggingFilter implements ContainerRequestFilter, ContainerResponseFilter {


  private static final Logger LOG = LoggerFactory.getLogger(LoggingFilter.class);
  private static final String STOPWATCH_PROPERTY = LoggingFilter.class.getName() + "stopwatch";

  private static final Comparator<Map.Entry<String, List<String>>> COMPARATOR =
    (o1, o2) -> o1.getKey().compareToIgnoreCase(o2.getKey());

  private static final String MDC_ID = "request_id";
  private static final String MDC_OUTPUT_BYTECOUNT = "output_bytecount";
  private static final String MDC_DURATION_MILLISECONDS = "duration_full_milliseconds";
  private static final String MDC_TIME_TO_FIRST_BYTE = "duration_time_to_first_byte_milliseconds";
  private static final String MDC_PRE_LOG = "request_log";
  private static final String MDC_POST_LOG = "response_log";
  private static final String MDC_HTTP_METHOD = "http_method";
  private static final String MDC_HTTP_URI = "http_uri";
  private static final String MDC_HTTP_PATH = "http_path";
  private static final String MDC_HTTP_AUTHORITY = "http_authority";
  private static final String MDC_HTTP_QUERY = "http_query";
  private static final String MDC_REQUEST_HEADERS = "http_request_headers";
  private static final String MDC_REQUEST_ENTITY = "http_request_content";
  private static final String MDC_HTTP_STATUS = "http_status";
  private static final String MDC_RESPONSE_HEADERS = "http_response_headers";
  private static final String MDC_RELEASE_HASH = "git_hash";
  private static final String MDC_OUTPUT_SNIPPET = "http_response_body";
  public static final String EMPTY = "";

  private final int entityLogSize;
  private final String releaseHash;

  public LoggingFilter(final int entityLogSize, String releaseHash) {
    this.entityLogSize = entityLogSize;
    this.releaseHash = releaseHash;
  }

  private String formatHeaders(final MultivaluedMap<String, String> headers) {
    final StringBuilder builder = new StringBuilder();
    for (final Map.Entry<String, List<String>> headerEntry : getSortedHeaders(headers.entrySet())) {
      final List<?> val = headerEntry.getValue();
      final String header = headerEntry.getKey();

      builder.append(header).append(": ");
      if (val.size() == 1) {
        builder.append(val.get(0));
      } else {
        boolean add = false;
        for (final Object s : val) {
          if (add) {
            builder.append(',');
          }
          add = true;
          builder.append(s);
        }
      }
      builder.append("\n");
    }
    return builder.toString();
  }

  private Set<Map.Entry<String, List<String>>> getSortedHeaders(final Set<Map.Entry<String, List<String>>> headers) {
    final TreeSet<Map.Entry<String, List<String>>> sortedHeaders = new TreeSet<>(COMPARATOR);
    sortedHeaders.addAll(headers);
    return sortedHeaders;
  }

  private InputStream addInboundEntityToMdc(InputStream stream, final Charset charset) throws IOException {
    final StringBuilder builder = new StringBuilder();
    if (!stream.markSupported()) {
      stream = new BufferedInputStream(stream);
    }
    stream.mark(entityLogSize + 1);
    final byte[] entity = new byte[entityLogSize + 1];
    final int entitySize = stream.read(entity);
    builder.append(new String(entity, 0, Math.min(entitySize, entityLogSize), charset));
    if (entitySize > entityLogSize) {
      builder.append(" (capped at ").append(entityLogSize).append(" bytes)");
    }
    MDC.put(MDC_REQUEST_ENTITY, builder.toString());
    stream.reset();
    return stream;
  }

  @Override
  public void filter(final ContainerRequestContext context) throws IOException {
    final Stopwatch stopwatch = Stopwatch.createStarted();
    final UUID id = UUID.randomUUID();
    MDC.put(MDC_ID, id.toString());
    MDC.put(MDC_RELEASE_HASH, releaseHash);

    MDC.put(MDC_PRE_LOG, "true");
    //Log a very minimal message. Mostly to make sure that we notice requests that never log in the response filter
    LOG.info(">     " + context.getMethod() + " " + context.getUriInfo().getRequestUri().toASCIIString());
    MDC.remove(MDC_PRE_LOG);
    context.setProperty(STOPWATCH_PROPERTY, stopwatch);

    if (context.hasEntity()) {
      context.setEntityStream(
        addInboundEntityToMdc(context.getEntityStream(), MessageUtils.getCharset(context.getMediaType()))
      );
    }
  }

  @Override
  public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext)
    throws IOException {
    if (!"1".equals(requestContext.getHeaderString("Is-Healthcheck"))) {
      //Actual log is done when the response stream has finished in aroundWriteTo
      String log = "< " +
        Integer.toString(responseContext.getStatus()) + " " +
        requestContext.getMethod() + " " +
        requestContext.getUriInfo().getRequestUri().toASCIIString();

      Stopwatch stopwatch = (Stopwatch) requestContext.getProperty(STOPWATCH_PROPERTY);
      if (stopwatch == null) {
        LOG.error("Lost my stopwatch!");
      } else if (!stopwatch.isRunning()) {
        LOG.error("Stopwatch was stopped!");
        stopwatch = null;
      }

      if (responseContext.hasEntity()) {
        String contentType = responseContext.getHeaderString("Content-Type");
        boolean logResponseText = "text/plain".equals(contentType) || "application/json".equals(contentType);
        //delay logging until the responseBody has been fully written
        responseContext.setEntityStream(new LoggingOutputStream(
          responseContext.getEntityStream(),
          stopwatch,
          log,
          requestContext,
          responseContext,
          MDC.getCopyOfContextMap(),
          logResponseText
        ));
      } else {
        //log now, because the writeTo wrapper will not be called
        long duration = stopwatch != null ? stopwatch.elapsed(MILLISECONDS) : -1;
        doLog(
          log,
          0,
          duration,
          duration,
          requestContext,
          responseContext,
          MDC.getCopyOfContextMap(),
          EMPTY
        );
      }
    }
  }

  private void doLog(String log, long bytecount, long totalDuration, long timeToFirstByte,
                     ContainerRequestContext requestContext,
                     ContainerResponseContext responseContext, Map<String, String> mdcVals, String responseBody) {
    //store current MDC state somewhere
    final Map<String, String> curMdc = MDC.getCopyOfContextMap();
    clearMdc();
    MDC.setContextMap(mdcVals);
    MDC.put(MDC_POST_LOG, "true");
    MDC.put(MDC_HTTP_METHOD, requestContext.getMethod());
    MDC.put(MDC_HTTP_URI, requestContext.getUriInfo().getRequestUri().toASCIIString());
    MDC.put(MDC_HTTP_PATH, requestContext.getUriInfo().getRequestUri().getPath());
    MDC.put(MDC_HTTP_AUTHORITY, requestContext.getUriInfo().getRequestUri().getAuthority());
    MDC.put(MDC_HTTP_QUERY, requestContext.getUriInfo().getRequestUri().getQuery());
    MDC.put(MDC_REQUEST_HEADERS, formatHeaders(requestContext.getHeaders()));

    MDC.put(MDC_HTTP_STATUS, Integer.toString(responseContext.getStatus()));
    MDC.put(MDC_RESPONSE_HEADERS, formatHeaders(responseContext.getStringHeaders()));

    MDC.put(MDC_OUTPUT_BYTECOUNT, bytecount + "");
    if (responseBody.length() > 0) {
      MDC.put(MDC_OUTPUT_SNIPPET, responseBody);
    }
    String size = " (" + bytecount + " bytes)";

    MDC.put(MDC_DURATION_MILLISECONDS, totalDuration + "");
    String durationLog;
    if (totalDuration != timeToFirstByte) {
      durationLog = " (" + totalDuration + "/" + timeToFirstByte + " ms)";
    } else {
      durationLog = " (" + totalDuration + " ms)";
    }

    LOG.info(log + size + durationLog);
    clearMdc();
    /*
     * The api of MDC.getCopyOfContextMap() says it may be null, so curMdc can be. In slf4j-api 1.7.24 it sometimes it
     * will be null. In slf4j-api 1.7.12 it never appeared.
     */
    if (curMdc != null) {
      MDC.setContextMap(curMdc);
    }
  }

  private void clearMdc() {
    MDC.remove(MDC_ID);
    MDC.remove(MDC_OUTPUT_BYTECOUNT);
    MDC.remove(MDC_DURATION_MILLISECONDS);
    MDC.remove(MDC_PRE_LOG);
    MDC.remove(MDC_POST_LOG);
    MDC.remove(MDC_HTTP_METHOD);
    MDC.remove(MDC_HTTP_URI);
    MDC.remove(MDC_HTTP_PATH);
    MDC.remove(MDC_HTTP_AUTHORITY);
    MDC.remove(MDC_HTTP_QUERY);
    MDC.remove(MDC_REQUEST_HEADERS);
    MDC.remove(MDC_HTTP_STATUS);
    MDC.remove(MDC_RESPONSE_HEADERS);
    MDC.remove(MDC_RELEASE_HASH);
    MDC.remove(MDC_REQUEST_ENTITY);
  }

  private class LoggingOutputStream extends FilterOutputStream {
    public static final int MAX_RESULT_SIZE = 2048;
    private final Stopwatch stopwatch;
    private final String log;
    private final ContainerRequestContext requestContext;
    private final ContainerResponseContext responseContext;
    private final Map<String, String> contextMap;
    private final boolean logResponseText;
    private long count = 0;
    private ByteArrayOutputStream responseBody = new ByteArrayOutputStream(MAX_RESULT_SIZE);
    long firstByte = -1;

    public LoggingOutputStream(OutputStream out, Stopwatch stopwatch, String log,
                               ContainerRequestContext requestContext, ContainerResponseContext responseContext,
                               Map<String, String> contextMap, boolean logResponseText) {
      super(Preconditions.checkNotNull(out));

      this.stopwatch = stopwatch;
      this.log = log;
      this.requestContext = requestContext;
      this.responseContext = responseContext;
      this.contextMap = contextMap;
      this.logResponseText = logResponseText;
    }

    public long getCount() {
      return this.count;
    }

    public void write(byte[] bytes, int off, int len) throws IOException {
      if (firstByte == -1) {
        firstByte = stopwatch != null ? stopwatch.elapsed(MILLISECONDS) : -1;
      }
      this.out.write(bytes, off, len);
      if (logResponseText && count < MAX_RESULT_SIZE - 1) {
        int writeLen = (int) count + len;
        if (writeLen > MAX_RESULT_SIZE) {
          writeLen = MAX_RESULT_SIZE - (int) count;
        }
        responseBody.write(bytes, off, writeLen);
      }
      this.count += len;
    }

    public void write(int someByte) throws IOException {
      if (firstByte == -1) {
        firstByte = stopwatch != null ? stopwatch.elapsed(MILLISECONDS) : -1;
      }
      this.out.write(someByte);
      if (logResponseText && count < MAX_RESULT_SIZE - 1) {
        responseBody.write(someByte);
      }
      ++this.count;
    }

    public void close() throws IOException {
      this.out.close();
      doLog(log,
        count,
        stopwatch != null ? stopwatch.elapsed(MILLISECONDS) : -1,
        firstByte,
        requestContext,
        responseContext,
        contextMap,
        responseBody.toString("UTF-8")
      );
      responseBody.close();
    }
  }
}