/*
 * Copyright 2012-2013 inBloom, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.slc.sli.api.security.oauth;

import java.io.IOException;
import java.io.StringWriter;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.Response;

import org.apache.commons.lang3.tuple.Pair;
import org.codehaus.jackson.map.ObjectMapper;
import org.slc.sli.api.init.RealmInitializer;
import org.slc.sli.api.jersey.exceptionhandlers.OAuthAccessExceptionHandler;
import org.slc.sli.api.security.OauthSessionManager;
import org.slc.sli.api.security.saml.SamlHelper;
import org.slc.sli.api.util.SecurityUtil;
import org.slc.sli.api.util.SecurityUtil.SecurityTask;
import org.slc.sli.domain.Entity;
import org.slc.sli.domain.NeutralQuery;
import org.slc.sli.domain.Repository;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Scope;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.CookieValue;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.servlet.ModelAndView;

/**
 * Controller for Discovery Service.
 *
 * <p>
 * Spring MVC controller mounted under {@code /oauth}. It routes a user to an authentication realm
 * (directly, or via a realm-selection page), starts SAML single sign-on against that realm's IdP,
 * and exchanges an authorization code for an access token.
 * </p>
 *
 * @author dkornishev
 */
@Controller
@Scope("request")
@RequestMapping("/oauth")
public class AuthController {

    private static final Logger LOG = LoggerFactory.getLogger(AuthController.class);

    // NOTE(review): not referenced anywhere in this class.
    private static final Pattern BASIC_AUTH = Pattern.compile("Basic (.+)", Pattern.CASE_INSENSITIVE);

    /** IdP type code handed to SamlHelper when the realm's idpType is not "Siteminder". */
    private static final int DEFAULT_IDP_TYPE = 1;

    /** IdP type code handed to SamlHelper when the realm's idpType is "Siteminder" (case-insensitive). */
    private static final int SITEMINDER_IDP_TYPE = 4;

    @Autowired
    private SamlHelper saml;

    @Autowired
    private OauthSessionManager sessionManager;

    @Autowired
    @Value("${sli.sandbox.enabled}")
    private boolean sandboxEnabled;

    @Autowired
    @Qualifier("validationRepo")
    private Repository<Entity> repo;

    /**
     * Entry point of the authorize flow: decides which realm the user authenticates against.
     *
     * <p>Resolution order:</p>
     * <ol>
     * <li>If the {@code _tla} cookie maps to an OAuth session whose principal has a realm, SSO is
     * started for that realm.</li>
     * <li>If exactly one realm is available (or {@code Realm} matches a realm's unique identifier),
     * SSO is started for it.</li>
     * <li>In sandbox mode, SSO is started for the realm whose unique identifier is
     * {@link RealmInitializer#ADMIN_REALM_ID}.</li>
     * <li>Otherwise the {@code realms} selection view is rendered.</li>
     * </ol>
     *
     * @param redirectUri
     *            client redirect URI, may be null
     * @param realmUniqueId
     *            optional realm unique identifier used to skip realm selection
     * @param clientId
     *            OAuth client id
     * @param state
     *            client state, forwarded to the model / SSO initiation
     * @param sessionId
     *            value of the {@code _tla} cookie, may be null
     * @param res
     *            servlet response, forwarded to {@link #ssoInit}
     * @param model
     *            spring injected model
     * @return a {@code redirect:} directive or the name of the template to use
     * @throws IOException
     *             as declared by {@link #ssoInit}
     * @throws IllegalStateException
     *             in sandbox mode when no Sandbox/Admin realm is defined
     */
    @RequestMapping(value = "authorize", method = RequestMethod.GET)
    public String listRealms(@RequestParam(value = "redirect_uri", required = false) final String redirectUri,
            @RequestParam(value = "Realm", required = false) final String realmUniqueId,
            @RequestParam(value = "client_id", required = true) final String clientId,
            @RequestParam(value = "state", required = false) final String state,
            @CookieValue(value = "_tla", required = false) final String sessionId, final HttpServletResponse res,
            final Model model) throws IOException {

        if (sessionId != null) {
            String realmId = getRealmId(sessionId);
            if (realmId != null) {
                LOG.debug("found valid session --> user authenticated in realm with id: {}", realmId);
                return ssoInit(realmId, sessionId, redirectUri, clientId, state, res, model);
            }
            LOG.debug("session does not map to a valid oauth session");
        }

        // In sandbox mode the map values are realm unique identifiers, otherwise realm names.
        Map<String, String> map = getRealmMap(realmUniqueId, sandboxEnabled);

        // Only one realm, so bypass the realm selection and direct the user straight to it.
        if (map.size() == 1) {
            String realmId = map.keySet().iterator().next();
            return ssoInit(realmId, sessionId, redirectUri, clientId, state, res, model);
        }

        // More than one realm, but in sandbox mode we always go to the Sandbox/Admin realm.
        // This branch always returns or throws, so the selection page below is never shown in sandbox mode.
        if (sandboxEnabled) {
            String realmId = null;
            for (Map.Entry<String, String> entry : map.entrySet()) {
                if (entry.getValue().equals(RealmInitializer.ADMIN_REALM_ID)) {
                    realmId = entry.getKey();
                }
            }

            if (realmId == null) {
                LOG.error("No Sandbox/Admin Simple-IDP Realm is defined which is required when in sandbox mode");
                throw new IllegalStateException("No Sandbox/Admin Simple-IDP Realm is defined which is required when in sandbox mode");
            }
            return ssoInit(realmId, sessionId, redirectUri, clientId, state, res, model);
        }

        model.addAttribute("redirect_uri", redirectUri != null ? redirectUri : "");
        model.addAttribute("clientId", clientId);
        model.addAttribute("state", state);
        model.addAttribute("dummy", new HashMap<String, String>());
        model.addAttribute("realms", map);
        return "realms";
    }

    /**
     * Loads the realms visible across all tenants, keyed by realm entity id.
     *
     * @param realmUniqueId
     *            if non-empty and some realm's {@code uniqueIdentifier} equals it, only that realm
     *            is returned
     * @param useUniqueIdentifier
     *            when true the map values are realm {@code uniqueIdentifier}s, otherwise realm
     *            {@code name}s
     * @return realm entity id mapped to the realm's unique identifier or name
     */
    private Map<String, String> getRealmMap(final String realmUniqueId, final boolean useUniqueIdentifier) {
        // The realm lookup needs both all-tenant visibility and sudo access (see SecurityUtil).
        return SecurityUtil.runWithAllTenants(new SecurityTask<Map<String, String>>() {

            @Override
            public Map<String, String> execute() {
                return SecurityUtil.sudoRun(new SecurityTask<Map<String, String>>() {
                    @Override
                    public Map<String, String> execute() {
                        return loadRealms(realmUniqueId, useUniqueIdentifier);
                    }
                });
            }
        });
    }

    /**
     * Reads every realm entity from the repository; caller is expected to supply the security context.
     *
     * @param realmUniqueId
     *            optional unique identifier of the single realm to return
     * @param useUniqueIdentifier
     *            whether map values are unique identifiers (true) or names (false)
     * @return realm entity id mapped to the selected display value
     */
    private Map<String, String> loadRealms(final String realmUniqueId, final boolean useUniqueIdentifier) {
        boolean filterRequested = realmUniqueId != null && !realmUniqueId.isEmpty();
        String valueKey = useUniqueIdentifier ? "uniqueIdentifier" : "name";

        Map<String, String> realms = new HashMap<String, String>();
        for (Entity realmEntity : repo.findAll("realm", new NeutralQuery())) {
            Map<String, Object> body = realmEntity.getBody();
            String value = (String) body.get(valueKey);

            if (filterRequested && realmUniqueId.equals(body.get("uniqueIdentifier"))) {
                // Requested realm found: return only it so that realm selection is short-circuited.
                Map<String, String> only = new HashMap<String, String>();
                only.put(realmEntity.getEntityId(), value);
                return only;
            }
            realms.put(realmEntity.getEntityId(), value);
        }
        return realms;
    }

    /**
     * Exchanges an authorization code (plus client credentials) for an access token.
     *
     * @param authorizationCode
     *            the {@code code} request parameter
     * @param redirectUri
     *            the {@code redirect_uri} request parameter; not used by this implementation
     * @param authz
     *            the optional {@code Authorization} header; not used by this implementation
     * @param clientId
     *            OAuth client id
     * @param clientSecret
     *            OAuth client secret
     * @param model
     *            spring injected model; not used
     * @return 200 with a JSON {@code access_token} body and {@code Cache-Control: no-store}, or the
     *         error response built by {@link OAuthAccessExceptionHandler} when verification fails
     * @throws BadCredentialsException
     *             as declared; presumably propagated from the session manager — confirm
     */
    @RequestMapping(value = "token", method = { RequestMethod.POST, RequestMethod.GET })
    public ResponseEntity<String> getAccessToken(@RequestParam("code") String authorizationCode,
            @RequestParam("redirect_uri") String redirectUri,
            @RequestHeader(value = "Authorization", required = false) String authz,
            @RequestParam("client_id") String clientId, @RequestParam("client_secret") String clientSecret, Model model)
            throws BadCredentialsException {
        String token;
        try {
            token = this.sessionManager.verify(authorizationCode, Pair.of(clientId, clientSecret));
        } catch (OAuthAccessException e) {
            return handleAccessException(e);
        }

        HttpHeaders headers = new HttpHeaders();
        headers.set("Cache-Control", "no-store");
        headers.setContentType(MediaType.APPLICATION_JSON);

        String response = String.format("{\"access_token\":\"%s\"}", token);

        return new ResponseEntity<String>(response, headers, HttpStatus.OK);
    }

    // Normally we would let the ExceptionHandler for OauthAccessException handle the
    // exception automatically, but since it gets thrown as part of a Spring request handler
    // and not jax-rs, it doesn't get invoked automatically.
    /**
     * Converts an {@link OAuthAccessException} into a Spring response by reusing the JAX-RS handler's
     * status and serializing its entity as JSON.
     *
     * @param e
     *            the access exception raised during token verification
     * @return response with the handler's status; the body is empty if serialization fails
     */
    private ResponseEntity<String> handleAccessException(OAuthAccessException e) {
        OAuthAccessExceptionHandler handler = new OAuthAccessExceptionHandler();
        Response resp = handler.toResponse(e);
        ObjectMapper mapper = new ObjectMapper();
        StringWriter writer = new StringWriter();
        try {
            mapper.writeValue(writer, resp.getEntity());
        } catch (Exception e1) {
            LOG.error("Error handling exception", e1);
        }
        return new ResponseEntity<String>(writer.getBuffer().toString(), HttpStatus.valueOf(resp.getStatus()));
    }

    /**
     * Redirects the user to the SSO endpoint of the given realm's IdP with a SAML AuthnRequest.
     *
     * @param realmIndex
     *            entity id of the realm (request parameter {@code realmId})
     * @param sessionId
     *            existing OAuth session id, may be null
     * @param redirectUri
     *            client redirect URI, may be null
     * @param clientId
     *            OAuth client id
     * @param state
     *            client state, stored with the app session
     * @param res
     *            servlet response (not used directly)
     * @param model
     *            spring injected model (not used directly)
     * @return directive to redirect to the IdP's SSO init page
     * @throws IOException
     *             as declared by the handler signature
     * @throws IllegalArgumentException
     *             if the realm does not exist or has no IdP configuration / redirect endpoint
     */
    @RequestMapping(value = "sso", method = { RequestMethod.GET, RequestMethod.POST })
    public String ssoInit(@RequestParam(value = "realmId", required = true) final String realmIndex,
            @RequestParam(value = "sessionId", required = false) final String sessionId,
            @RequestParam(value = "redirect_uri", required = false) String redirectUri,
            @RequestParam(value = "clientId", required = true) final String clientId,
            @RequestParam(value = "state", required = false) final String state, HttpServletResponse res, Model model)
            throws IOException {

        String realmId = getRealmId(sessionId);
        boolean isExpired = isSessionExpired(sessionId);
        // Re-authentication is forced unless the session exists, is tied to a realm and is still live.
        boolean forceAuthn = sessionId == null || realmId == null || isExpired;

        // Ugly, but we need both sudo access and full tenant access
        Entity realmEnt = SecurityUtil.sudoRun(new SecurityTask<Entity>() {

            @Override
            public Entity execute() {
                return SecurityUtil.runWithAllTenants(new SecurityTask<Entity>() {
                    @Override
                    public Entity execute() {
                        Entity ent = repo.findById("realm", realmIndex);
                        if (ent == null) {
                            throw new IllegalArgumentException("couldn't locate idp for realm: " + realmIndex);
                        }
                        return ent;
                    }
                });
            }
        });

        String tenantId = (String) realmEnt.getBody().get("tenantId");

        @SuppressWarnings("unchecked")
        Map<String, String> idpData = (Map<String, String>) realmEnt.getBody().get("idp");
        if (idpData == null) {
            throw new IllegalArgumentException("realm " + realmIndex + " doesn't have an idp configuration");
        }

        String endpoint = idpData.get("redirectEndpoint");
        if (endpoint == null) {
            throw new IllegalArgumentException("realm " + realmIndex + " doesn't have an endpoint");
        }
        String idpTypeString = idpData.get("idpType");

        LOG.debug("creating saml authnrequest with ForceAuthn equal to {}", forceAuthn);

        int idpType = "Siteminder".equalsIgnoreCase(idpTypeString) ? SITEMINDER_IDP_TYPE : DEFAULT_IDP_TYPE;

        // pair contains {messageId,encodedSAML}
        Pair<String, String> tuple = saml.createSamlAuthnRequestForRedirect(endpoint, forceAuthn, idpType);

        sessionManager.createAppSession(sessionId, clientId, redirectUri, state, tenantId, realmEnt.getEntityId(),
                tuple.getLeft(), isExpired);

        LOG.debug("redirecting to: {}", endpoint);

        return "redirect:" + buildSsoRedirectUrl(endpoint, tuple.getRight(), realmIndex);
    }

    /**
     * Appends the SAML request and relay state to an IdP endpoint, joining with {@code &} when the
     * endpoint already has a query string and with {@code ?} otherwise.
     *
     * <p>
     * NOTE(review): {@code samlRequest} and {@code relayState} are appended as-is; this assumes
     * SamlHelper already URL-encodes the request and that realm ids are URL-safe — confirm.
     * </p>
     *
     * @param endpoint
     *            IdP redirect endpoint
     * @param samlRequest
     *            encoded SAML AuthnRequest
     * @param relayState
     *            value sent back by the IdP (the realm id)
     * @return the full redirect URL
     */
    private static String buildSsoRedirectUrl(String endpoint, String samlRequest, String relayState) {
        String separator = endpoint.contains("?") ? "&" : "?";
        return endpoint + separator + "SAMLRequest=" + samlRequest + "&RelayState=" + relayState;
    }

    /**
     * Looks up the realm id recorded on the principal of an OAuth session.
     *
     * @param sessionId
     *            OAuth session id, may be null
     * @return the principal's {@code realm}, or null when there is no such session or it has no
     *         principal
     */
    @SuppressWarnings("unchecked")
    private String getRealmId(final String sessionId) {
        if (sessionId == null) {
            return null;
        }
        Entity session = sessionManager.getSession(sessionId);
        if (session == null) {
            return null;
        }
        Map<String, Object> principal = (Map<String, Object>) session.getBody().get("principal");
        return principal == null ? null : (String) principal.get("realm");
    }

    /**
     * Decides whether an OAuth session can no longer be used.
     *
     * @param sessionId
     *            OAuth session id, may be null
     * @return true when there is no such session, when its {@code expiration} or {@code hardLogout}
     *         time (compared against {@link System#currentTimeMillis()}) has been reached, or when
     *         either timestamp is missing or not numeric
     */
    private boolean isSessionExpired(final String sessionId) {
        if (sessionId == null) {
            return true;
        }
        Entity session = sessionManager.getSession(sessionId);
        if (session == null) {
            return true;
        }
        Map<String, Object> body = session.getBody();
        Object expiration = body.get("expiration");
        Object hardLogout = body.get("hardLogout");
        if (!(expiration instanceof Number) || !(hardLogout instanceof Number)) {
            // Malformed session: treat as expired instead of failing with an NPE / ClassCastException.
            return true;
        }
        long now = System.currentTimeMillis();
        return now >= ((Number) expiration).longValue() || now >= ((Number) hardLogout).longValue();
    }

    /**
     * Maps OAuth2 errors raised by this controller's handlers to an error view.
     *
     * @param e
     *            the OAuth2 exception
     * @return the error view chosen by {@link #handleSandboxExceptions(Exception)}
     */
    @ExceptionHandler(OAuth2Exception.class)
    public ModelAndView handleOAuth2Exception(OAuth2Exception e) {
        return handleSandboxExceptions(e);
    }

    /**
     * Chooses the error view: {@code error_403} with a fixed message in sandbox mode, otherwise
     * {@code error_500}.
     *
     * @param e
     *            the exception being handled (not inspected)
     * @return the error view
     */
    private ModelAndView handleSandboxExceptions(Exception e) {
        if (sandboxEnabled) {
            ModelAndView mav = new ModelAndView("error_403");
            mav.addObject("errorMessage", "Custom error message");
            return mav;
        }
        return new ModelAndView("error_500");
    }

}