package de.rwth.idsg.steve.ocpp.soap;

import lombok.extern.slf4j.Slf4j;
import org.apache.cxf.Bus;
import org.apache.cxf.binding.soap.SoapMessage;
import org.apache.cxf.binding.soap.SoapVersion;
import org.apache.cxf.binding.soap.SoapVersionFactory;
import org.apache.cxf.endpoint.Server;
import org.apache.cxf.endpoint.ServerRegistry;
import org.apache.cxf.interceptor.InterceptorChain;
import org.apache.cxf.interceptor.StaxInInterceptor;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.AbstractPhaseInterceptor;
import org.apache.cxf.phase.Phase;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.staxutils.DepthXMLStreamReader;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.cxf.transport.MessageObserver;

import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLStreamConstants;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static de.rwth.idsg.steve.SteveConfiguration.CONFIG;

/**
 * Taken from http://cxf.apache.org/docs/service-routing.html and modified.
 *
 */
@Slf4j
public class MediatorInInterceptor extends AbstractPhaseInterceptor<SoapMessage> {

    private final XMLInputFactory xmlInputFactory = XMLInputFactory.newInstance();
    private final Map<String, Server> actualServers = new HashMap<>(2);

    public MediatorInInterceptor() {
        super(Phase.POST_STREAM);
        super.addBefore(StaxInInterceptor.class.getName());
    }

    public final void handleMessage(SoapMessage message) {
        String schemaNamespace = "";
        InterceptorChain chain = message.getInterceptorChain();

        // Scan the incoming message for its schema namespace
        try {
            // Create a buffered stream so that we get back the original stream after scanning
            InputStream is = message.getContent(InputStream.class);
            BufferedInputStream bis = new BufferedInputStream(is);
            bis.mark(bis.available());
            message.setContent(InputStream.class, bis);

            String encoding = (String) message.get(Message.ENCODING);
            XMLStreamReader reader = xmlInputFactory.createXMLStreamReader(bis, encoding);
            DepthXMLStreamReader xmlReader = new DepthXMLStreamReader(reader);

            if (xmlReader.nextTag() == XMLStreamConstants.START_ELEMENT) {
                String ns = xmlReader.getNamespaceURI();
                SoapVersion soapVersion = SoapVersionFactory.getInstance().getSoapVersion(ns);
                // Advance just past header
                StaxUtils.toNextTag(xmlReader, soapVersion.getBody());
                // Past body
                xmlReader.nextTag();
            }
            schemaNamespace = xmlReader.getName().getNamespaceURI();
            bis.reset();

        } catch (IOException | XMLStreamException ex) {
            log.error("Exception happened", ex);
        }

        // Init the lookup, when the first message ever arrives
        if (actualServers.isEmpty()) {
            initServerLookupMap(message);
        }

        // We redirect the message to the actual OCPP service
        Server targetServer = actualServers.get(schemaNamespace);

        // Redirect the request
        if (targetServer != null) {
            MessageObserver mo = targetServer.getDestination().getMessageObserver();
            mo.onMessage(message);
        }

        // Now the response has been put in the message, abort the chain
        chain.abort();
    }

    /**
     * Iterate over all available servers registered on the bus and build a map
     * consisting of (namespace, server) pairs for later lookup, so we can
     * redirect to the version-specific implementation according to the namespace
     * of the incoming message.
     */
    private void initServerLookupMap(SoapMessage message) {
        Bus bus = message.getExchange().getBus();

        ServerRegistry serverRegistry = bus.getExtension(ServerRegistry.class);
        if (serverRegistry == null) {
            return;
        }

        List<Server> temp = serverRegistry.getServers();
        for (Server server : temp) {
            EndpointInfo info = server.getEndpoint().getEndpointInfo();
            String address = info.getAddress();

            // exclude the 'dummy' routing server
            if (CONFIG.getRouterEndpointPath().equals(address)) {
                continue;
            }

            String serverNamespace = info.getName().getNamespaceURI();
            actualServers.put(serverNamespace, server);
        }
    }
}