/*
 * Copyright (c) 2012-2016 Codenvy, S.A.
 * This program and the accompanying materials are made
 * available under the terms of the Eclipse Public License 2.0
 * which is available at https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *   Codenvy, S.A. - initial API and implementation
 */
package org.everrest.core.servlet;

import org.everrest.core.ExtHttpHeaders;
import org.everrest.core.impl.ContainerRequest;
import org.everrest.core.impl.InputHeadersMap;
import org.everrest.core.impl.MultivaluedMapImpl;
import org.slf4j.LoggerFactory;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.SecurityContext;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.security.Principal;
import java.util.Enumeration;

/**
 * {@link ContainerRequest} that is populated from an {@link HttpServletRequest}.
 * <p>
 * Instances are obtained through {@link #create(HttpServletRequest)}. When the request carries a valid
 * {@code X-Forwarded-Host} header, the host and port from that header are used to build the base and request URIs
 * instead of the server name and port reported by the servlet container.
 *
 * @author andrew00x
 */
public class ServletContainerRequest extends ContainerRequest {

    private static final org.slf4j.Logger LOG = LoggerFactory.getLogger(ServletContainerRequest.class);

    /** Largest port number that can appear in a URL authority. */
    private static final int MAX_PORT = 65535;

    /** Marker returned by {@link #parseForwardedPort(String)} when the port text is not a valid port. */
    private static final int INVALID_PORT = Integer.MIN_VALUE;

    /**
     * Create a container request from a servlet request.
     *
     * @param req
     *         {@link HttpServletRequest}
     * @return new ServletContainerRequest with method, URIs, entity stream, headers and security context taken from
     *         {@code req}
     */
    public static ServletContainerRequest create(final HttpServletRequest req) {
        // If the URL is forwarded, obtain the forwarding information
        final URL forwardedUrl = getForwardedUrl(req);
        final String host;
        int port;
        if (forwardedUrl == null) {
            host = req.getServerName();
            port = req.getServerPort();
        } else {
            host = forwardedUrl.getHost();
            port = forwardedUrl.getPort();
            if (port < 0) {
                // The header carried no port: fall back to the default port of the request scheme
                port = forwardedUrl.getDefaultPort();
            }
            LOG.debug("Assuming forwarded URL: {}", forwardedUrl);
        }

        // The common URI prefix for both baseUri and requestUri
        final String scheme = getScheme(req);
        final StringBuilder commonUriBuilder = new StringBuilder();
        commonUriBuilder.append(scheme).append("://").append(host);
        // Leave the port out when it is unknown or is the well-known port of the scheme
        final boolean defaultHttpPort = port == 80 && "http".equals(scheme);
        final boolean defaultHttpsPort = port == 443 && "https".equals(scheme);
        if (port >= 0 && !defaultHttpPort && !defaultHttpsPort) {
            commonUriBuilder.append(':').append(port);
        }
        final String commonUriPrefix = commonUriBuilder.toString();

        // The Base URI - up to the servlet path
        final URI baseUri = URI.create(commonUriPrefix + req.getContextPath() + req.getServletPath());

        // The RequestURI - everything in the URL
        final StringBuilder requestUriBuilder = new StringBuilder(commonUriPrefix);
        requestUriBuilder.append(req.getRequestURI());
        final String queryString = req.getQueryString();
        if (queryString != null) {
            requestUriBuilder.append('?').append(queryString);
        }
        final URI requestUri = URI.create(requestUriBuilder.toString());

        return new ServletContainerRequest(getMethod(req), requestUri, baseUri, getEntityStream(req), getHeaders(req),
                                           getSecurityContext(req));
    }

    private ServletContainerRequest(String method, URI requestUri, URI baseUri, InputStream entityStream,
                                    MultivaluedMap<String, String> httpHeaders, SecurityContext securityContext) {
        super(method, requestUri, baseUri, entityStream, httpHeaders, securityContext);
    }

    /**
     * Extract HTTP method name from servlet request.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return HTTP method name
     * @see HttpServletRequest#getMethod()
     */
    private static String getMethod(HttpServletRequest servletRequest) {
        return servletRequest.getMethod();
    }

    /**
     * Extract the scheme from servlet request.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return scheme of the request
     * @see HttpServletRequest#getScheme()
     */
    private static String getScheme(HttpServletRequest servletRequest) {
        return servletRequest.getScheme();
    }

    /**
     * Get the URL that is forwarded using the standard X-Forwarded-Host header.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return The URL of the forwarded host. If the header is missing or invalid, null is returned.
     */
    private static URL getForwardedUrl(HttpServletRequest servletRequest) {
        final String forwardedHostAndPort = servletRequest.getHeader(ExtHttpHeaders.FORWARDED_HOST);
        if (forwardedHostAndPort == null || forwardedHostAndPort.isEmpty()) {
            return null;
        }
        final URL url = parseForwardedHostHeader(forwardedHostAndPort, servletRequest);
        if (url == null) {
            LOG.warn("Ignoring invalid {}: {}", ExtHttpHeaders.FORWARDED_HOST, forwardedHostAndPort);
        }
        return url;
    }

    /**
     * Parse according to IETF standard for Host field: http://tools.ietf.org/html/rfc7230#section-5.4
     * <p>
     * Accepted forms are {@code host}, {@code host:port}, {@code [ipv6]} and {@code [ipv6]:port}. An empty port
     * ({@code host:}) is treated as no port.
     *
     * @param forwardedHostAndPort
     *         non-empty value of the forwarded host header
     * @param servletRequest
     *         {@link HttpServletRequest}, used only to obtain the scheme of the resulting URL
     * @return URL made of the request scheme and the forwarded host and port, or null if the value cannot be parsed
     */
    private static URL parseForwardedHostHeader(String forwardedHostAndPort, HttpServletRequest servletRequest) {
        final String fwdHost;
        final String portPart;
        if (forwardedHostAndPort.startsWith("[")) {
            // Bracketed IPv6 literal: colons inside the brackets belong to the address, not to the port separator
            final int closingBracket = forwardedHostAndPort.indexOf(']');
            if (closingBracket < 0) {
                return null;
            }
            fwdHost = forwardedHostAndPort.substring(0, closingBracket + 1);
            final String remainder = forwardedHostAndPort.substring(closingBracket + 1);
            if (remainder.isEmpty()) {
                portPart = null;
            } else if (remainder.charAt(0) == ':') {
                portPart = remainder.substring(1);
            } else {
                return null;
            }
        } else {
            final int colon = forwardedHostAndPort.indexOf(':');
            if (colon < 0) {
                fwdHost = forwardedHostAndPort;
                portPart = null;
            } else {
                if (forwardedHostAndPort.indexOf(':', colon + 1) >= 0) {
                    // More than one colon outside of brackets is not a valid host[:port]
                    return null;
                }
                fwdHost = forwardedHostAndPort.substring(0, colon);
                portPart = forwardedHostAndPort.substring(colon + 1);
            }
        }
        if (fwdHost.isEmpty()) {
            // Values such as ":" or ":8080" carry no host at all
            return null;
        }
        final int fwdPort = parseForwardedPort(portPart);
        if (fwdPort == INVALID_PORT) {
            return null;
        }
        final String scheme = getScheme(servletRequest);
        try {
            return new URI(scheme, null, fwdHost, fwdPort, null, null, null).toURL();
        } catch (URISyntaxException | MalformedURLException e) {
            LOG.debug("Unable to build URL for forwarded host {}", forwardedHostAndPort, e);
        }
        return null;
    }

    /**
     * Parse the port part of a forwarded host header.
     *
     * @param portPart
     *         text after the port separator, may be null or empty when no port is given
     * @return the port number, {@code -1} if no port is given or {@link #INVALID_PORT} if the text is not a sequence
     *         of decimal digits in the range 0-65535
     */
    private static int parseForwardedPort(String portPart) {
        if (portPart == null || portPart.isEmpty()) {
            return -1;
        }
        // Only plain decimal digits are allowed; Integer.parseInt alone would also accept a leading sign
        for (int i = 0; i < portPart.length(); i++) {
            final char c = portPart.charAt(i);
            if (c < '0' || c > '9') {
                return INVALID_PORT;
            }
        }
        try {
            final int port = Integer.parseInt(portPart);
            return port <= MAX_PORT ? port : INVALID_PORT;
        } catch (NumberFormatException e) {
            // All characters are digits, so this can only be a value too large for an int
            return INVALID_PORT;
        }
    }

    /**
     * Get HTTP headers from {@link HttpServletRequest} .
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return request headers, empty if the servlet request does not expose any headers
     */
    private static MultivaluedMap<String, String> getHeaders(HttpServletRequest servletRequest) {
        MultivaluedMap<String, String> h = new MultivaluedMapImpl();
        Enumeration<String> headerNames = servletRequest.getHeaderNames();
        // The servlet API allows a container to return null when access to header names is not permitted
        if (headerNames != null) {
            while (headerNames.hasMoreElements()) {
                String name = headerNames.nextElement();
                Enumeration<String> values = servletRequest.getHeaders(name);
                if (values == null) {
                    continue;
                }
                while (values.hasMoreElements()) {
                    h.add(name, values.nextElement());
                }
            }
        }
        return new InputHeadersMap(h);
    }

    /**
     * Get input stream from {@link HttpServletRequest} .
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return request stream as returned by {@link HttpServletRequest#getInputStream()}
     * @throws RuntimeException
     *         if the servlet request fails to provide the stream; the original {@link IOException} is the cause
     */
    private static InputStream getEntityStream(HttpServletRequest servletRequest) {
        try {
            return servletRequest.getInputStream();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create a {@link SecurityContext} that delegates every call to the servlet request.
     *
     * @param servletRequest
     *         {@link HttpServletRequest}
     * @return security context backed by {@code servletRequest}
     */
    private static SecurityContext getSecurityContext(final HttpServletRequest servletRequest) {
        return new SecurityContext() {
            @Override
            public Principal getUserPrincipal() {
                return servletRequest.getUserPrincipal();
            }

            @Override
            public boolean isUserInRole(String role) {
                return servletRequest.isUserInRole(role);
            }

            @Override
            public boolean isSecure() {
                return servletRequest.isSecure();
            }

            @Override
            public String getAuthenticationScheme() {
                return servletRequest.getAuthType();
            }
        };
    }
}