/*
 * Copyright 2002-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.mthizo247.cloud.netflix.zuul.web.socket;

import com.github.mthizo247.cloud.netflix.zuul.web.proxytarget.ProxyTargetResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.util.Assert;
import org.springframework.util.ErrorHandler;
import org.springframework.util.PatternMatchUtils;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.util.UriComponentsBuilder;

import java.net.URI;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link WebSocketHandlerDecorator} that adds web socket support to zuul reverse proxy.
 *
 * <p>Every message is first passed to the decorated handler and is then inspected for a
 * STOMP command. {@code CONNECT}, {@code SEND}, {@code SUBSCRIBE} and {@code UNSUBSCRIBE}
 * frames are forwarded to a per-session {@link ProxyWebSocketConnectionManager}; any other
 * command is only logged at debug level.
 *
 * @author Ronald Mthombeni
 * @author Salman Noor
 */
public class ProxyWebSocketHandler extends WebSocketHandlerDecorator {
    private static final Logger logger = LoggerFactory.getLogger(ProxyWebSocketHandler.class);
    private final WebSocketHttpHeadersCallback headersCallback;
    private final SimpMessagingTemplate messagingTemplate;
    private final ProxyTargetResolver proxyTargetResolver;
    private final ZuulWebSocketProperties zuulWebSocketProperties;
    private final WebSocketStompClient stompClient;
    /** One connection manager per client session, registered on CONNECT and removed on close. */
    private final Map<WebSocketSession, ProxyWebSocketConnectionManager> managers = new ConcurrentHashMap<>();
    private ErrorHandler errorHandler;

    /**
     * Creates a new proxying handler around the given delegate.
     *
     * @param delegate                the handler being decorated
     * @param stompClient             STOMP client handed to each connection manager
     * @param headersCallback         headers callback handed to each connection manager
     * @param messagingTemplate       messaging template handed to each connection manager
     * @param proxyTargetResolver     resolves the target URI for a brokerage
     * @param zuulWebSocketProperties configured brokerages and their end points
     */
    public ProxyWebSocketHandler(WebSocketHandler delegate,
                                 WebSocketStompClient stompClient,
                                 WebSocketHttpHeadersCallback headersCallback,
                                 SimpMessagingTemplate messagingTemplate,
                                 ProxyTargetResolver proxyTargetResolver,
                                 ZuulWebSocketProperties zuulWebSocketProperties) {
        super(delegate);
        this.stompClient = stompClient;
        this.headersCallback = headersCallback;
        this.messagingTemplate = messagingTemplate;
        this.proxyTargetResolver = proxyTargetResolver;
        this.zuulWebSocketProperties = zuulWebSocketProperties;
    }

    /**
     * Sets the error handler that is passed on to every connection manager created
     * after this call.
     *
     * @param errorHandler the handler to propagate; may be {@code null}
     */
    public void errorHandler(ErrorHandler errorHandler) {
        this.errorHandler = errorHandler;
    }

    /**
     * Returns the first end point of the given brokerage that matches the path of
     * {@code uri}, or {@code null} when none matches.
     */
    private String getWebSocketServerPath(ZuulWebSocketProperties.WsBrokerage wsBrokerage,
                                          URI uri) {
        String path = extractPath(uri);
        for (String endPoint : wsBrokerage.getEndPoints()) {
            if (matchesEndPoint(endPoint, path)) {
                return endPoint;
            }
        }

        return null;
    }

    /**
     * Returns the first enabled brokerage that has an end point matching the path of
     * {@code uri}, or {@code null} when none matches.
     */
    private ZuulWebSocketProperties.WsBrokerage getWebSocketBrokarage(URI uri) {
        String path = extractPath(uri);
        for (ZuulWebSocketProperties.WsBrokerage wsBrokerage : zuulWebSocketProperties
                .getBrokerages().values()) {
            if (!wsBrokerage.isEnabled()) {
                continue;
            }
            for (String endPoint : wsBrokerage.getEndPoints()) {
                if (matchesEndPoint(endPoint, path)) {
                    return wsBrokerage;
                }
            }
        }

        return null;
    }

    /**
     * Reduces the URI to its path component. A URI string without a colon is used as is;
     * otherwise only the parsed path is kept. A missing path is returned as an empty
     * string so that it is not later concatenated as the literal text "null".
     */
    private String extractPath(URI uri) {
        String path = uri.toString();
        if (path.contains(":")) {
            path = UriComponentsBuilder.fromUriString(path).build().getPath();
        }
        return path != null ? path : "";
    }

    /**
     * Tests whether the path (with a trailing slash appended) matches the wildcard
     * pattern derived from the end point.
     */
    private boolean matchesEndPoint(String endPoint, String path) {
        return PatternMatchUtils.simpleMatch(toPattern(endPoint), path + "/");
    }

    /** Wraps an end point as {@code **&#47;endPoint/**}, adding slashes only where missing. */
    private String toPattern(String path) {
        path = path.startsWith("/") ? "**" + path : "**/" + path;
        return path.endsWith("/") ? path + "**" : path + "/**";
    }

    /**
     * Disconnects the session's connection manager, if any, before notifying the delegate.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus)
            throws Exception {
        disconnectFromProxiedTarget(session);
        super.afterConnectionClosed(session, closeStatus);
    }

    /**
     * Lets the delegate handle the message first, then forwards it to the proxied target.
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message)
            throws Exception {
        super.handleMessage(session, message);
        handleMessageFromClient(session, message);
    }

    /**
     * Dispatches a client message on its STOMP command. The commands are mutually
     * exclusive, so at most one branch runs; unrecognised commands are logged only.
     */
    private void handleMessageFromClient(WebSocketSession session,
                                         WebSocketMessage<?> message) throws Exception {
        WebSocketMessageAccessor accessor = WebSocketMessageAccessor.create(message);
        String command = accessor.getCommand();

        if (StompCommand.SEND.toString().equalsIgnoreCase(command)) {
            sendMessageToProxiedTarget(session, accessor);
        } else if (StompCommand.SUBSCRIBE.toString().equalsIgnoreCase(command)) {
            subscribeToProxiedTarget(session, accessor);
        } else if (StompCommand.UNSUBSCRIBE.toString().equalsIgnoreCase(command)) {
            unsubscribeFromProxiedTarget(session, accessor);
        } else if (StompCommand.CONNECT.toString().equalsIgnoreCase(command)) {
            connectToProxiedTarget(session);
        } else {
            logger.debug("STOMP COMMAND {} was not explicitly handled", command);
        }
    }

    /**
     * Resolves the target for the session's URI, then creates, registers and starts a
     * connection manager for it.
     *
     * @throws IllegalArgumentException if the session has no URI, or no brokerage,
     *                                  end point or route target can be resolved for it
     */
    private void connectToProxiedTarget(WebSocketSession session) {
        URI sessionUri = session.getUri();
        Assert.notNull(sessionUri, "session uri must not be null");

        ZuulWebSocketProperties.WsBrokerage wsBrokerage = getWebSocketBrokarage(
                sessionUri);
        Assert.notNull(wsBrokerage, "wsBrokerage must not be null");

        String path = getWebSocketServerPath(wsBrokerage, sessionUri);
        Assert.notNull(path, "Web socket uri path must not be null");

        URI routeTarget = proxyTargetResolver.resolveTarget(wsBrokerage);
        Assert.notNull(routeTarget, "routeTarget must not be null");

        String uri = ServletUriComponentsBuilder
                .fromUri(routeTarget)
                .path(path)
                .replaceQuery(sessionUri.getQuery())
                .toUriString();

        ProxyWebSocketConnectionManager connectionManager = new ProxyWebSocketConnectionManager(
                messagingTemplate, stompClient, session, headersCallback, uri);
        connectionManager.errorHandler(this.errorHandler);

        // A repeated CONNECT on the same session would otherwise orphan the manager
        // that is already registered, so the displaced one is disconnected here.
        ProxyWebSocketConnectionManager previous = managers.put(session, connectionManager);
        disconnectProxyManager(previous);

        connectionManager.start();
    }

    /** Unregisters and disconnects the session's connection manager, if one exists. */
    private void disconnectFromProxiedTarget(WebSocketSession session) {
        disconnectProxyManager(managers.remove(session));
    }

    /**
     * Best-effort disconnect: a failure must not prevent the remaining close handling,
     * so it is logged rather than rethrown.
     */
    private void disconnectProxyManager(ProxyWebSocketConnectionManager proxyManager) {
        if (proxyManager == null) {
            return;
        }
        try {
            proxyManager.disconnect();
        } catch (Throwable ex) {
            logger.debug("Failed to disconnect proxy connection manager", ex);
        }
    }

    /**
     * Looks up the connection manager registered for the session. Logs a warning and
     * returns {@code null} when none is registered, e.g. when a frame arrives before
     * CONNECT or after the session was closed.
     */
    private ProxyWebSocketConnectionManager findManager(WebSocketSession session,
                                                        StompCommand command) {
        ProxyWebSocketConnectionManager manager = managers.get(session);
        if (manager == null) {
            logger.warn("Ignoring STOMP {}: no proxy connection is registered for session {}",
                    command, session);
        }
        return manager;
    }

    /** Forwards an UNSUBSCRIBE for the accessor's destination, if a manager exists. */
    private void unsubscribeFromProxiedTarget(WebSocketSession session,
                                              WebSocketMessageAccessor accessor) {
        ProxyWebSocketConnectionManager manager = findManager(session, StompCommand.UNSUBSCRIBE);
        if (manager != null) {
            manager.unsubscribe(accessor.getDestination());
        }
    }

    /** Forwards a SEND with the accessor's destination and payload, if a manager exists. */
    private void sendMessageToProxiedTarget(WebSocketSession session,
                                            WebSocketMessageAccessor accessor) {
        ProxyWebSocketConnectionManager manager = findManager(session, StompCommand.SEND);
        if (manager != null) {
            manager.sendMessage(accessor.getDestination(), accessor.getPayload());
        }
    }

    /** Forwards a SUBSCRIBE for the accessor's destination, if a manager exists. */
    private void subscribeToProxiedTarget(WebSocketSession session,
                                          WebSocketMessageAccessor accessor) throws Exception {
        ProxyWebSocketConnectionManager manager = findManager(session, StompCommand.SUBSCRIBE);
        if (manager != null) {
            manager.subscribe(accessor.getDestination());
        }
    }
}