/*
 * Copyright 2002-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.mthizo247.cloud.netflix.zuul.web.socket;

import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandler;
import org.springframework.util.ErrorHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.ConnectionManagerSupport;
import org.springframework.web.socket.messaging.WebSocketStompClient;

import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A web socket connection manager bridge between client and backend server via zuul
 * reverse proxy.
 *
 * <p>Each instance pairs one user-agent {@link WebSocketSession} with one backend
 * {@link StompSession}. Frames received from the backend are relayed to the user
 * agent through the supplied {@link SimpMessagingTemplate}; the destinations the
 * manager is subscribed to are tracked so they can be restored by
 * {@link #reconnect(long)}.
 *
 * @author Ronald Mthombeni
 * @author Salman Noor
 */
public class ProxyWebSocketConnectionManager extends ConnectionManagerSupport
        implements StompSessionHandler {
    private final WebSocketStompClient stompClient;
    private final WebSocketSession userAgentSession;
    private final WebSocketHttpHeadersCallback httpHeadersCallback;
    private final SimpMessagingTemplate messagingTemplate;
    /** Active backend subscriptions keyed by destination. */
    private final Map<String, StompSession.Subscription> subscriptions = new ConcurrentHashMap<>();
    private StompSession serverSession;
    private ErrorHandler errorHandler;

    /**
     * Creates a manager that proxies the given user-agent session to the backend at
     * {@code uri}.
     *
     * @param messagingTemplate   template used to relay backend frames to the user agent
     * @param stompClient         client used to open the backend STOMP session
     * @param userAgentSession    the web socket session of the connected user agent
     * @param httpHeadersCallback optional callback that contributes handshake headers;
     *                            may be {@code null}
     * @param uri                 the backend web socket uri
     */
    public ProxyWebSocketConnectionManager(SimpMessagingTemplate messagingTemplate,
                                           WebSocketStompClient stompClient, WebSocketSession userAgentSession,
                                           WebSocketHttpHeadersCallback httpHeadersCallback, String uri) {
        super(uri);
        this.messagingTemplate = messagingTemplate;
        this.stompClient = stompClient;
        this.userAgentSession = userAgentSession;
        this.httpHeadersCallback = httpHeadersCallback;
    }

    /**
     * Sets the handler notified of backend session and transport errors.
     *
     * @param errorHandler the handler to notify, or {@code null} to ignore errors
     */
    public void errorHandler(ErrorHandler errorHandler) {
        this.errorHandler = errorHandler;
    }

    /**
     * Builds the handshake headers for the backend connection, letting the optional
     * callback populate them from the user-agent session.
     */
    private WebSocketHttpHeaders buildWebSocketHttpHeaders() {
        WebSocketHttpHeaders wsHeaders = new WebSocketHttpHeaders();
        if (httpHeadersCallback != null) {
            httpHeadersCallback.applyHeaders(userAgentSession, wsHeaders);
        }
        return wsHeaders;
    }

    @Override
    protected void openConnection() {
        connect();
    }

    /**
     * Opens the backend STOMP session and blocks until the connection attempt
     * completes.
     *
     * @throws RuntimeException wrapping the underlying cause if the connection could
     *                          not be established
     */
    public void connect() {
        try {
            serverSession = stompClient
                    .connect(getUri().toString(), buildWebSocketHttpHeaders(), this)
                    .get();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Keep the interrupt visible to callers further up the stack.
                Thread.currentThread().interrupt();
            }
            logger.error("Error connecting to web socket uri " + getUri(), e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Re-establishes the backend session and restores every subscription that was
     * active before the call.
     *
     * <p>If the calling thread is interrupted while waiting out {@code delay}, the
     * interrupt flag is restored and the reconnect is abandoned. A failure to restore
     * an individual subscription is logged and does not prevent the remaining ones
     * from being attempted.
     *
     * @param delay milliseconds to wait before reconnecting; values {@code <= 0}
     *              reconnect immediately
     * @throws RuntimeException if the backend connection could not be established
     */
    public void reconnect(final long delay) {
        if (delay > 0) {
            logger.warn("Connection lost or refused, will attempt to reconnect after "
                    + delay + " millis");
            try {
                Thread.sleep(delay);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.warn("Interrupted while waiting to reconnect to " + getUri()
                        + ", reconnect aborted");
                return;
            }
        }

        // Snapshot first: subscribe() below writes to the same map.
        Set<String> destinations = new HashSet<>(subscriptions.keySet());

        connect();

        for (String destination : destinations) {
            try {
                subscribe(destination);
            } catch (Exception e) {
                logger.warn("Failed to restore subscription to " + destination
                        + " after reconnect", e);
            }
        }
    }

    @Override
    protected void closeConnection() throws Exception {
        if (isConnected()) {
            this.serverSession.disconnect();
        }
    }

    @Override
    protected boolean isConnected() {
        return (this.serverSession != null && this.serverSession.isConnected());
    }

    @Override
    public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
        if (logger.isDebugEnabled()) {
            logger.debug("Proxied target now connected " + session);
        }
    }

    /**
     * Forwards a frame-handling failure to the configured error handler, if any,
     * wrapped in a {@link ProxySessionException}.
     */
    @Override
    public void handleException(StompSession session, StompCommand command,
                                StompHeaders headers, byte[] payload, Throwable ex) {
        if (errorHandler != null) {
            errorHandler.handleError(new ProxySessionException(this, session, ex));
        }
    }

    /**
     * Forwards a transport-level failure to the configured error handler, if any,
     * wrapped in a {@link ProxySessionException}.
     */
    @Override
    public void handleTransportError(StompSession session, Throwable ex) {
        if (errorHandler != null) {
            errorHandler.handleError(new ProxySessionException(this, session, ex));
        }
    }

    @Override
    public Type getPayloadType(StompHeaders headers) {
        return Object.class;
    }

    /**
     * Sends a message to the backend over the current server session.
     *
     * <p>String payloads are sent as raw UTF-8 bytes so that an already serialized
     * (e.g. JSON) string is not converted a second time by the message converters.
     * This method does not open the connection itself; a backend session must
     * already exist.
     *
     * @param destination the backend destination
     * @param msg         the payload to send
     */
    public void sendMessage(final String destination, final Object msg) {
        if (msg instanceof String) {
            // Explicit charset: the no-arg getBytes() depends on the platform default.
            serverSession.send(destination, ((String) msg).getBytes(StandardCharsets.UTF_8));
            return;
        }

        serverSession.send(destination, msg);
    }

    /**
     * Relays a frame received from the backend to the user agent.
     *
     * <p>Frames without a destination are ignored. When the user agent has a
     * principal and the destination starts with the template's user destination
     * prefix, the prefix is stripped and the frame is sent to that user; otherwise
     * it is sent to the destination unchanged.
     */
    @Override
    public void handleFrame(StompHeaders headers, Object payload) {
        String destination = headers.getDestination();
        if (destination == null) {
            return;
        }

        if (logger.isDebugEnabled()) {
            logger.debug("Received " + payload + ", To " + destination);
        }

        Map<String, Object> relayedHeaders = copyHeaders(headers.toSingleValueMap());
        Principal principal = userAgentSession.getPrincipal();
        String userDestinationPrefix = messagingTemplate.getUserDestinationPrefix();

        if (principal != null && destination.startsWith(userDestinationPrefix)) {
            String userDestination = destination.substring(userDestinationPrefix.length());
            if (!userDestination.startsWith("/")) {
                userDestination = "/" + userDestination;
            }
            messagingTemplate.convertAndSendToUser(principal.getName(), userDestination,
                    payload, relayedHeaders);
        } else {
            messagingTemplate.convertAndSend(destination, payload, relayedHeaders);
        }
    }

    /** Copies the single-valued STOMP headers into the map type the template expects. */
    private Map<String, Object> copyHeaders(Map<String, String> original) {
        return new HashMap<String, Object>(original);
    }

    private void connectIfNecessary() {
        if (!isConnected()) {
            connect();
        }
    }

    /**
     * Subscribes to a backend destination, connecting first if necessary, and
     * records the subscription so it can be removed or restored later.
     *
     * @param destination the backend destination to subscribe to
     * @throws Exception if connecting or subscribing fails
     */
    public void subscribe(String destination) throws Exception {
        connectIfNecessary();
        StompSession.Subscription subscription = serverSession.subscribe(destination,
                this);
        subscriptions.put(destination, subscription);
    }

    /**
     * Removes the subscription for the given destination, if one is tracked.
     *
     * @param destination the backend destination to unsubscribe from
     */
    public void unsubscribe(String destination) {
        StompSession.Subscription subscription = subscriptions.remove(destination);
        if (subscription != null) {
            // NOTE(review): if this reconnects, the subscription handle still belongs
            // to the previous session - confirm unsubscribing it is meaningful.
            connectIfNecessary();
            subscription.unsubscribe();
        }
    }

    /**
     * @return {@code true} if the user-agent session exists and is still open
     */
    public boolean isConnectedToUserAgent() {
        return (userAgentSession != null && userAgentSession.isOpen());
    }

    /**
     * Closes the backend session on a best-effort basis; a failure to close is
     * logged at debug level and otherwise ignored.
     */
    public void disconnect() {
        try {
            closeConnection();
        } catch (Exception e) {
            if (logger.isDebugEnabled()) {
                logger.debug("Error while disconnecting from " + getUri(), e);
            }
        }
    }
}