package yushijinhun.authlibagent.web.yggdrasil.servlet;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.CompletionException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.web.context.support.SpringBeanAutowiringSupport;
import yushijinhun.authlibagent.service.YggdrasilService;
import yushijinhun.authlibagent.web.yggdrasil.ResponseSerializer;

/**
 * Base class for servlets serving Yggdrasil API endpoints.
 * <p>
 * Subclasses implement {@link #process(HttpServletRequest)}. {@link #handleRequest(HttpServletRequest, HttpServletResponse)}
 * runs it inside a Spring transaction and writes either its result or a JSON error object to the response. The
 * {@code @Autowired}/{@code @Value} fields are injected in {@link #init()} from the Spring context attached to the
 * servlet context.
 */
abstract public class YggdrasilServlet extends HttpServlet {

	private static final long serialVersionUID = 1L;

	protected final Logger LOGGER = LogManager.getFormatterLogger(getClass());

	@Autowired
	protected YggdrasilService backend;

	@Autowired
	protected ResponseSerializer serializer;

	@Autowired
	private TransactionTemplate transactionTemplate;

	/** Exception class canonical name -> value of the {@code error} field reported to clients. */
	@Value("#{errorNames}")
	private Map<String, String> errorNames;

	/** Exception class canonical name -> HTTP status code reported to clients. */
	@Value("#{errorCodes}")
	private Map<String, Integer> errorCodes;

	/** Whether error responses include a {@code cause} field naming the exception's cause. */
	@Value("#{config['security.showErrorCause']}")
	private boolean showErrorCause;

	/**
	 * Injects this servlet's Spring-managed fields from the web application context bound to the servlet context.
	 */
	@Override
	public void init() throws ServletException {
		SpringBeanAutowiringSupport.processInjectionBasedOnServletContext(this, getServletContext());
	}

	/**
	 * Runs {@link #process(HttpServletRequest)} in a transaction and writes its outcome to {@code resp}.
	 * <ul>
	 * <li>Non-null result: status 200, the result's string form written as {@code application/json}.</li>
	 * <li>Null result: status 204, no body.</li>
	 * <li>Any other failure: status taken from {@code errorCodes} (500 when unmapped) and a JSON error object with
	 * {@code error}, optional {@code errorMessage} and, if enabled, {@code cause}.</li>
	 * </ul>
	 *
	 * @throws IOException if {@code process} or writing the response throws it
	 * @throws ServletException if {@code process} throws it
	 */
	protected void handleRequest(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException {
		int respCode;
		Object jsonResp;
		try {
			jsonResp = wrapProcess(req);
			respCode = jsonResp == null ? 204 : 200;
		} catch (IOException | ServletException e) {
			// I/O and servlet failures are left to the container rather than turned into JSON errors
			throw e;
		} catch (Throwable e) {
			log("exception during process request", e);

			respCode = getConfiguredErrorCode(e.getClass());
			if (respCode == -1) {
				respCode = 500;
				LOGGER.warn("unexpected exception", e);
			}
			jsonResp = buildErrorJson(e);
		}

		resp.setStatus(respCode);
		if (jsonResp != null) {
			resp.setContentType("application/json; charset=utf-8");
			resp.getWriter().print(jsonResp);
		}
	}

	/**
	 * Invokes {@link #process(HttpServletRequest)} through the {@link TransactionTemplate} and rethrows whatever it
	 * threw, unwrapped.
	 */
	private Object wrapProcess(HttpServletRequest req) throws Throwable {
		try {
			return transactionTemplate.execute(status -> {
				try {
					return process(req);
				} catch (Throwable e) {
					// the callback cannot throw checked exceptions; tunnel everything through an unchecked wrapper
					throw new CompletionException(e);
				}
			});
		} catch (CompletionException e) {
			Throwable cause = e.getCause();
			// our wrapper always carries a cause; never degrade to "throw null" if some other one slipped through
			throw cause != null ? cause : e;
		}
	}

	/**
	 * Handles one request; called inside a transaction by {@link #handleRequest}.
	 *
	 * @return the value to write as the response body (via its string form), or {@code null} for a 204 response
	 * @throws Exception any failure; it is mapped to an error response by {@link #handleRequest}
	 */
	abstract protected Object process(HttpServletRequest req) throws Exception;

	/**
	 * Builds the error response object for {@code e}: {@code error} always, {@code errorMessage} when the exception
	 * has a message, and {@code cause} when {@code showErrorCause} is set and the exception has a cause.
	 */
	private JSONObject buildErrorJson(Throwable e) {
		JSONObject errJson = new JSONObject();
		errJson.put("error", lookupErrorName(e));

		String message = e.getMessage();
		if (message != null) {
			errJson.put("errorMessage", message);
		}

		String cause = showErrorCause ? lookupErrorName(e.getCause()) : null;
		if (cause != null) {
			errJson.put("cause", cause);
		}
		return errJson;
	}

	/**
	 * Returns the client-facing error name for {@code e}: the configured name for its class hierarchy, else the
	 * class's simple name; {@code null} when {@code e} is {@code null}.
	 */
	private String lookupErrorName(Throwable e) {
		if (e == null) {
			return null;
		}
		String exName = getConfiguredErrorName(e.getClass());
		return exName != null ? exName : e.getClass().getSimpleName();
	}

	/** Returns the configured error name for {@code clazz} or its nearest mapped superclass, or {@code null}. */
	private String getConfiguredErrorName(Class<?> clazz) {
		return lookupInHierarchy(errorNames, clazz);
	}

	/** Returns the configured HTTP status for {@code clazz} or its nearest mapped superclass, or {@code -1}. */
	private int getConfiguredErrorCode(Class<?> clazz) {
		Integer code = lookupInHierarchy(errorCodes, clazz);
		return code != null ? code : -1;
	}

	/**
	 * Walks {@code clazz} and its superclasses and returns the value mapped to the first class whose canonical name is
	 * present in {@code table}, or {@code null} when none is.
	 */
	private static <T> T lookupInHierarchy(Map<String, T> table, Class<?> clazz) {
		for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
			String key = current.getCanonicalName();
			// anonymous/local classes have no canonical name; skip them instead of querying with a null key,
			// which null-hostile maps (e.g. Properties/Hashtable) reject with an NPE
			if (key == null) {
				continue;
			}
			T value = table.get(key);
			if (value != null) {
				return value;
			}
		}
		return null;
	}

}