/**
 * The MIT License
 * Copyright (c) 2015 Population Register Centre
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
package fi.vm.kapa.identification.shibboleth.extauthn;

import fi.vm.kapa.identification.adapter.service.AuthenticationHandlerService;
import fi.vm.kapa.identification.adapter.service.RequestVerificationService;
import fi.vm.kapa.identification.service.UrlParamService;
import fi.vm.kapa.identification.tupas.TupasContext;
import fi.vm.kapa.identification.tupas.TupasIdentification;
import net.shibboleth.idp.authn.ExternalAuthentication;
import net.shibboleth.idp.authn.ExternalAuthenticationException;
import net.shibboleth.idp.authn.context.AuthenticationContext;
import net.shibboleth.idp.authn.context.RequestedPrincipalContext;
import net.shibboleth.idp.saml.authn.principal.AuthnContextDeclRefPrincipal;
import org.apache.commons.lang.StringUtils;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URLEncodedUtils;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Extensions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;

import java.io.IOException;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.*;

/**
 * Servlet that drives the Shibboleth IdP external authentication flow for TUPAS bank identification.
 *
 * <p>Dependencies are injected from the Spring web application context in {@link #init(ServletConfig)},
 * because the servlet itself is instantiated by the servlet container, not by Spring.
 */
@WebServlet(name="ShibbolethExtAuthnHandler", urlPatterns={"/authn/External/*"})
public class ShibbolethExtAuthnHandler extends HttpServlet {

    private static final Logger logger = LoggerFactory.getLogger(ShibbolethExtAuthnHandler.class);

    /** Character set of the identification parameters returned by the bank. */
    private static final Charset TUPAS_CHARSET = Charset.forName("ISO-8859-1");

    /** Language code used when the authentication request does not carry one. */
    private static final String DEFAULT_LANGUAGE = "fi";

    /** Base URL of the error page; error redirects append {@code ?m=<error key>} to it. */
    @Value("${default.error.url}")
    private String defaultErrorBase;

    /* These strings define the error redirect URL query parameter that can be
     * used to guide the error page. The value matches the property key that
     * fetches the correct language variant for the error message. The @Value
     * expressions are literals (no ${...} placeholder), so the injected value is
     * the key string itself.
     */
    @Value("failure.param.idp.ext")
    private String errorParamIdpExt;
    @Value("failure.param.entityid")
    private String errorParamInvalidEID;
    @Value("failure.param.authnreq")
    private String errorParamInvalidAuthenticationRequest;

    @Autowired
    private AuthenticationHandlerService authenticationHandlerService;

    @Autowired
    private RequestVerificationService securityService;

    /**
     * Initialises the servlet and autowires its fields from the Spring web application context.
     *
     * @param config servlet configuration supplied by the container
     * @throws ServletException if the Spring context is unavailable or autowiring fails
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // GenericServlet.init(ServletConfig) stores the config for getServletConfig(); an override
        // that does not call super leaves it null.
        super.init(config);
        try {
            WebApplicationContext springContext = WebApplicationContextUtils.getRequiredWebApplicationContext(config.getServletContext());
            AutowireCapableBeanFactory beanFactory = springContext.getAutowireCapableBeanFactory();
            beanFactory.autowireBean(this);
        } catch (RuntimeException e) {
            logger.error("Error initializing ShibbolethExtAuthnHandler", e);
            // Request handling depends on the autowired services and properties, so fail
            // initialisation rather than leave the servlet half-configured.
            throw new ServletException("Error initializing ShibbolethExtAuthnHandler", e);
        }
    }

    /**
     * Dispatches a GET request to the matching step of the external authentication process.
     *
     * <ol>
     *   <li>No {@code token} parameter: initial call from the IdP. The external authentication
     *       is started and the form that posts the TUPAS request to the bank is written to the
     *       response.</li>
     *   <li>{@code token} present and {@code msg=cancel}: the user cancelled in the bank. The
     *       session is purged and the external authentication finishes with an error.</li>
     *   <li>Otherwise: the bank redirected back with the token and the TUPAS response, which
     *       is verified.</li>
     * </ol>
     *
     * Any {@link ExternalAuthenticationException} results in a redirect to the error page.
     */
    @Override
    public void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
        try {
            if (StringUtils.isBlank(request.getParameter("token"))) {
                // If token ID is empty, this is an initial call from IdP (1.)
                initializeTupasIdentification(request, response);
            } else if ("cancel".equalsIgnoreCase(request.getParameter("msg"))) {
                // User cancelled in bank (2.)
                cancelTupasIdentification(request, response);
            } else {
                // User redirected with identification (3.)
                verifyTupasIdentification(request, response);
            }
        } catch (ExternalAuthenticationException eae) {
            logger.warn("Failure on external authentication: {}", eae.getMessage());
            response.sendRedirect(createErrorURL(errorParamIdpExt));
        }
    }

    /**
     * Handles a POSTed return from the bank (step 3): verifies the identification, or redirects
     * to the error page if verification fails.
     */
    @Override
    public void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
        // User redirected with identification (3.)
        try {
            verifyTupasIdentification(request, response);
        } catch (ExternalAuthenticationException e) {
            logger.warn("Failure on external authentication: {}", e.getMessage());
            response.sendRedirect(createErrorURL(errorParamIdpExt));
        }
    }

    /**
     * Starts the Shibboleth external authentication and writes the adapter's session
     * initialisation result (the form posting the TUPAS request to the bank) to the response.
     * Redirects to the error page if the relying party is invalid or no
     * {@link AuthenticationContext} is available.
     *
     * @throws ExternalAuthenticationException if Shibboleth cannot start the external
     *         authentication, or the request does not ask for exactly one declaration ref
     */
    private void initializeTupasIdentification(HttpServletRequest request, HttpServletResponse response) throws ExternalAuthenticationException, IOException {
        String convKey = ExternalAuthentication.startExternalAuthentication(request);
        /* The above method enriches the servlet request with Shibboleth IdP related
         * data such as the relying party value which is fetched in the code below.
         * Note that this data is not fetched from browser generated request but directly
         * from Shibboleth through internal class method call
         */
        // getSession() creates a session when none exists, so it never returns null.
        HttpSession session = request.getSession();
        String sessionId = session.getId();
        String relyingParty = String.valueOf(request.getAttribute(ExternalAuthentication.RELYING_PARTY_PARAM));

        // Relying party parameter must match the allowed entity ID format
        if (!UrlParamService.isValidEntityId(relyingParty)) {
            logger.warn("<<{}>> Received invalid relying party", relyingParty);
            response.sendRedirect(createErrorURL(errorParamInvalidEID));
            return;
        }

        ProfileRequestContext prc = ExternalAuthentication.getProfileRequestContext(convKey, request);
        AuthenticationContext ac = prc.getSubcontext(AuthenticationContext.class);
        if (ac == null) {
            logger.info("No AuthenticationContext");
            response.sendRedirect(createErrorURL(errorParamIdpExt));
            return;
        }

        MultivaluedMap<String, String> requestParams = new MultivaluedHashMap<>();
        requestParams.putSingle("declRef", resolveDeclarationRef(prc));
        requestParams.putSingle("lang", resolveLanguage(prc));
        requestParams.putSingle("ckey", convKey);
        requestParams.putSingle("sessionId", sessionId);
        String initResult = authenticationHandlerService.initializeSession(requestParams);
        response.getWriter().println(initResult);
    }

    /**
     * Handles a cancellation returned by the bank. Purges the session identified by the
     * {@code token} request parameter, then finishes the external authentication with an
     * authentication error.
     *
     * @throws ExternalAuthenticationException if no conversation key is found for the token,
     *         or Shibboleth fails to finish the external authentication
     */
    private void cancelTupasIdentification(HttpServletRequest request, HttpServletResponse response) throws ExternalAuthenticationException, IOException {
        String token = request.getParameter("token");
        logger.warn("Cancel request from bank: token: {}", token);
        MultivaluedMap<String, String> sessionParams = authenticationHandlerService.purgeSession(token);
        // Guard against purgeSession returning no session data for the token, instead of
        // failing with a NullPointerException.
        String convKey = (sessionParams == null) ? null : sessionParams.getFirst("ckey");
        if (StringUtils.isBlank(convKey)) {
            throw new ExternalAuthenticationException("No conversation key found for cancelled token");
        }
        request.setAttribute(ExternalAuthentication.AUTHENTICATION_ERROR_KEY, "User canceled authentication");
        ExternalAuthentication.finishExternalAuthentication(convKey, request, response);
    }

    /**
     * Verifies the identification returned by the bank and hands control back to the IdP.
     *
     * <p>On success, the method does three things:
     * <ul>
     *   <li>attaches a {@link TupasContext} to the {@link AuthenticationContext};</li>
     *   <li>sets the token as the principal name;</li>
     *   <li>finishes the external authentication.</li>
     * </ul>
     * If no {@link AuthenticationContext} is found, the authentication finishes with an error.
     *
     * @throws ExternalAuthenticationException if the token is missing, the adapter cannot
     *         build a session from the parameters, or Shibboleth reports an error
     */
    private void verifyTupasIdentification(HttpServletRequest request, HttpServletResponse response) throws ExternalAuthenticationException, IOException {
        // Identification requests are ISO-8859-1 encoded, so the raw query string is decoded
        // explicitly with that charset. URLEncodedUtils.parse decodes both "+" and "%20" to a
        // space. A request without a query string (e.g. a form POST) yields no parameters and is
        // rejected by the token check below.
        String queryString = request.getQueryString();
        List<NameValuePair> queryParams = (queryString == null)
                ? Collections.<NameValuePair>emptyList()
                : URLEncodedUtils.parse(queryString, TUPAS_CHARSET);
        MultivaluedMap<String, String> requestParams = convertToMap(queryParams);
        requestParams.putSingle("sessionId", request.getSession().getId());

        String token = requestParams.getFirst("token");
        if (StringUtils.isBlank(token)) {
            throw new ExternalAuthenticationException("Bad request, no token");
        }

        TupasIdentification identification = authenticationHandlerService.buildSession(requestParams);
        if (identification == null) {
            throw new ExternalAuthenticationException("Authentication verification failed");
        }

        String convKey = identification.getCkey();
        AuthenticationContext ac = ExternalAuthentication.getProfileRequestContext(convKey, request)
                .getSubcontext(AuthenticationContext.class);
        if (ac == null) {
            logger.warn("Authentication context not valid");
            request.setAttribute(ExternalAuthentication.AUTHENTICATION_ERROR_KEY, "Bad authentication");
            ExternalAuthentication.finishExternalAuthentication(convKey, request, response);
            return;
        }

        ac.addSubcontext(new TupasContext(identification.getName(), identification.getHetu()));
        request.setAttribute(ExternalAuthentication.PRINCIPAL_NAME_KEY, token);
        // External authentication success. Give control back to Shibboleth IdP (4.)
        ExternalAuthentication.finishExternalAuthentication(convKey, request, response);
    }

    /**
     * Builds the error page redirect URL.
     *
     * @param message error key passed to the error page as the {@code m} query parameter
     * @return the error page URL
     */
    private String createErrorURL(String message) {
        return defaultErrorBase + "?m=" + message;
    }

    /**
     * Resolves the authentication context declaration ref from the ProfileRequestContext.
     *
     * @param profileRequestContext profile request context of the current conversation
     * @return the single requested authentication context declaration ref
     * @throws ExternalAuthenticationException if no principals were requested, or the request
     *         does not contain exactly one {@link AuthnContextDeclRefPrincipal}
     */
    private String resolveDeclarationRef(ProfileRequestContext profileRequestContext) throws ExternalAuthenticationException {
        AuthenticationContext ac = profileRequestContext.getSubcontext(AuthenticationContext.class);
        RequestedPrincipalContext rpc = (ac == null) ? null : ac.getSubcontext(RequestedPrincipalContext.class);
        if (rpc == null) {
            // No RequestedPrincipalContext: nothing was requested, which cannot satisfy the
            // "exactly one" rule below.
            throw new ExternalAuthenticationException("Exactly one authentication context declaration ref must be requested");
        }
        List<Principal> principalList = rpc.getRequestedPrincipals();
        if (principalList == null || principalList.size() != 1 || !(principalList.get(0) instanceof AuthnContextDeclRefPrincipal)) {
            throw new ExternalAuthenticationException("Exactly one authentication context declaration ref must be requested");
        }
        AuthnContextDeclRefPrincipal declRefPrincipal = (AuthnContextDeclRefPrincipal) principalList.get(0);
        String declRef = declRefPrincipal.getAuthnContextDeclRef().getAuthnContextDeclRef();
        logger.debug("Requested declaration ref: {}", declRef);
        return declRef;
    }

    /**
     * Resolves the language from a vetuma-style {@code kapa/lang} SAML extension of the
     * inbound AuthnRequest.
     *
     * @param prc profile request context of the current conversation
     * @return the language code, or {@code "fi"} when none can be found
     */
    private String resolveLanguage(ProfileRequestContext prc) {
        Object inboundMessage = (prc.getInboundMessageContext() == null)
                ? null
                : prc.getInboundMessageContext().getMessage();
        // Guard the cast: anything other than an AuthnRequest carries no usable extension.
        Extensions extensions = (inboundMessage instanceof AuthnRequest)
                ? ((AuthnRequest) inboundMessage).getExtensions()
                : null;
        if (extensions == null) {
            logger.debug("Could not find language parameter in authentication request, using default language - {}", DEFAULT_LANGUAGE);
            return DEFAULT_LANGUAGE;
        }
        // look for vetuma-style language parameter for backward compatibility
        String vetumaLang = extensions.getOrderedChildren()
                .stream()
                .filter(extension -> "kapa".equals(extension.getElementQName().getLocalPart()))
                .findFirst()
                .flatMap(vetumaNode -> vetumaNode.getOrderedChildren()
                        .stream()
                        .filter(lgNode -> "lang".equals(lgNode.getElementQName().getLocalPart()))
                        .findFirst())
                // Optional.map turns a null DOM, an empty element or a null text value into
                // "absent" instead of throwing a NullPointerException.
                .map(langNode -> langNode.getDOM())
                .map(dom -> dom.getFirstChild())
                .map(child -> child.getNodeValue())
                .filter(StringUtils::isNotBlank)
                .orElse(DEFAULT_LANGUAGE);
        logger.debug("Resolved language parameter from authentication request - {}", vetumaLang);
        return vetumaLang;
    }

    /**
     * Converts decoded query parameters into a multivalued map. Values of repeated parameter
     * names are kept in their original order.
     */
    private MultivaluedMap<String, String> convertToMap(List<NameValuePair> queryParams) {
        MultivaluedMap<String, String> paramsMap = new MultivaluedHashMap<>();
        for (NameValuePair pair : queryParams) {
            paramsMap.add(pair.getName(), pair.getValue());
        }
        return paramsMap;
    }

}