package io.leangen.graphql.spqr.spring.web.apollo;

import graphql.ExecutionResult;
import graphql.GraphQL;
import io.leangen.graphql.spqr.spring.web.servlet.websocket.GraphQLWebSocketExecutor;
import io.leangen.graphql.spqr.spring.web.dto.GraphQLRequest;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicReference;

import static io.leangen.graphql.spqr.spring.web.apollo.ApolloMessage.GQL_CONNECTION_INIT;
import static io.leangen.graphql.spqr.spring.web.apollo.ApolloMessage.GQL_CONNECTION_TERMINATE;
import static io.leangen.graphql.spqr.spring.web.apollo.ApolloMessage.GQL_START;
import static io.leangen.graphql.spqr.spring.web.apollo.ApolloMessage.GQL_STOP;

/**
 * WebSocket handler that speaks the Apollo GraphQL-over-WebSocket message protocol.
 *
 * <p>It acknowledges {@code GQL_CONNECTION_INIT}, executes {@code GQL_START} operations, stops
 * subscriptions on {@code GQL_STOP}, and closes the session on {@code GQL_CONNECTION_TERMINATE}.
 * Queries and mutations produce a data message followed by a complete message. Subscriptions
 * stream one data (or error) message per event.
 *
 * <p>NOTE(review): active subscriptions and the keep-alive task are stored per handler instance.
 * This suggests one instance is meant to serve a single WebSocket connection; confirm against
 * the wiring code.
 */
class ApolloProtocolHandler extends TextWebSocketHandler {

    private final GraphQL graphQL;
    private final GraphQLWebSocketExecutor executor;
    private final TaskScheduler taskScheduler;
    private final int keepAliveInterval;
    /** Active subscriptions, keyed by the operation id sent by the client. */
    private final Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();
    /** The scheduled keep-alive task, or {@code null} when none is running. */
    private final AtomicReference<ScheduledFuture<?>> keepAlive = new AtomicReference<>();

    private static final Logger log = LoggerFactory.getLogger(ApolloProtocolHandler.class);

    /**
     * Creates a protocol handler.
     *
     * @param graphQL           the GraphQL instance operations are executed against
     * @param executor          executes incoming requests in the context of the WebSocket session
     * @param taskScheduler     scheduler used for keep-alive messages; {@code null} disables them
     * @param keepAliveInterval delay between keep-alive messages, in milliseconds (values below
     *                          1000 are raised to 1000)
     */
    public ApolloProtocolHandler(GraphQL graphQL, GraphQLWebSocketExecutor executor,
                                 TaskScheduler taskScheduler, int keepAliveInterval) {
        this.graphQL = graphQL;
        this.executor = executor;
        this.taskScheduler = taskScheduler;
        this.keepAliveInterval = keepAliveInterval;
    }

    /**
     * Starts the periodic keep-alive task for the session when a scheduler is configured.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        super.afterConnectionEstablished(session);
        if (taskScheduler != null) {
            ScheduledFuture<?> task = taskScheduler.scheduleWithFixedDelay(
                    keepAliveTask(session), Math.max(keepAliveInterval, 1000));
            // Only one keep-alive task is tracked; cancel a surplus one instead of leaking it.
            if (!keepAlive.compareAndSet(null, task)) {
                task.cancel(false);
            }
        }
    }

    /**
     * Cancels all active subscriptions and the keep-alive task once the connection is gone.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        cancelAll();
        ScheduledFuture<?> task = keepAlive.getAndSet(null);
        if (task != null) {
            task.cancel(false);
        }
    }

    /**
     * Closes the session with {@link CloseStatus#SERVER_ERROR} and cancels all subscriptions,
     * even if closing the session fails.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        try {
            session.close(CloseStatus.SERVER_ERROR);
        } finally {
            cancelAll();
        }
    }

    /**
     * Dispatches a single incoming protocol message.
     *
     * <p>An unparseable message is answered with a connection error. Any I/O failure while
     * replying is treated as fatal for the session.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        try {
            ApolloMessage apolloMessage;
            try {
                apolloMessage = ApolloMessage.from(message);
            } catch (IOException e) {
                send(session, ApolloMessage.connectionError());
                return;
            }
            switch (apolloMessage.getType()) {
                case GQL_CONNECTION_INIT:
                    send(session, ApolloMessage.connectionAck());
                    if (taskScheduler != null) {
                        send(session, ApolloMessage.keepAlive());
                    }
                    break;
                case GQL_START:
                    // A restarted operation id replaces whatever is still running under it.
                    // Otherwise the old stream would keep emitting and could never be stopped.
                    cancelSubscription(apolloMessage.getId());
                    GraphQLRequest request = ((StartMessage) apolloMessage).getPayload();
                    ExecutionResult result = executor.execute(graphQL, request, session);
                    if (result.getData() instanceof Publisher) {
                        handleSubscription(apolloMessage.getId(), result, session);
                    } else {
                        handleQueryOrMutation(apolloMessage.getId(), result, session);
                    }
                    break;
                case GQL_STOP:
                    cancelSubscription(apolloMessage.getId());
                    break;
                case GQL_CONNECTION_TERMINATE:
                    session.close();
                    cancelAll();
                    break;
                default:
                    // Other message types are ignored.
                    break;
            }
        } catch (IOException e) {
            fatalError(session, e);
        }
    }

    /**
     * Sends the single result of a query or mutation, followed by a complete message.
     */
    private void handleQueryOrMutation(String id, ExecutionResult result, WebSocketSession session) {
        try {
            send(session, ApolloMessage.data(id, result));
            send(session, ApolloMessage.complete(id));
        } catch (IOException e) {
            fatalError(session, e);
        }
    }

    /**
     * Subscribes to the publisher carried by {@code result} and forwards each event to the client.
     *
     * <p>Events are requested one at a time. The subscription is registered under {@code id} so
     * that a {@code GQL_STOP} can cancel it. It is unregistered again when the stream terminates.
     */
    private void handleSubscription(String id, ExecutionResult result, WebSocketSession session) {
        Publisher<ExecutionResult> stream = result.getData();
        Subscriber<ExecutionResult> subscriber = new Subscriber<ExecutionResult>() {
            private Subscription subscription;

            @Override
            public void onSubscribe(Subscription subscription) {
                this.subscription = subscription;
                subscriptions.put(id, subscription);
                request(1);
            }

            @Override
            public void onNext(ExecutionResult executionResult) {
                try {
                    if (executionResult.getErrors().isEmpty()) {
                        send(session, ApolloMessage.data(id, executionResult));
                    } else {
                        send(session, ApolloMessage.error(id, executionResult.getErrors()));
                    }
                } catch (IOException e) {
                    fatalError(session, e);
                }
                request(1);
            }

            @Override
            public void onError(Throwable t) {
                unregister();
                try {
                    send(session, ApolloMessage.error(id, t));
                } catch (IOException e) {
                    fatalError(session, e);
                }
            }

            @Override
            public void onComplete() {
                unregister();
                try {
                    send(session, ApolloMessage.complete(id));
                } catch (IOException e) {
                    fatalError(session, e);
                }
            }

            private void request(int n) {
                Subscription subscription = this.subscription;
                if (subscription != null) {
                    subscription.request(n);
                }
            }

            // Drops this terminated subscription from the registry. The conditional remove
            // leaves a newer subscription registered under the same id untouched.
            private void unregister() {
                Subscription subscription = this.subscription;
                if (subscription != null) {
                    subscriptions.remove(id, subscription);
                }
            }
        };
        stream.subscribe(subscriber);
    }

    /**
     * Cancels and unregisters every active subscription.
     */
    void cancelAll() {
        // Each entry is removed before it is cancelled. A subscription registered concurrently
        // is therefore either cancelled here or left in place; it is never dropped uncancelled.
        for (String id : subscriptions.keySet()) {
            Subscription subscription = subscriptions.remove(id);
            if (subscription != null) {
                subscription.cancel();
            }
        }
    }

    /**
     * Cancels and unregisters the subscription running under {@code id}, if there is one.
     */
    private void cancelSubscription(String id) {
        Subscription subscription = subscriptions.remove(id);
        if (subscription != null) {
            subscription.cancel();
        }
    }

    /**
     * Closes the session as unreliable, cancels all subscriptions and logs the cause.
     */
    private void fatalError(WebSocketSession session, Exception exception) {
        try {
            session.close(CloseStatus.SESSION_NOT_RELIABLE);
        } catch (Exception ignored) {
            // Best effort: the session is being abandoned either way.
        }
        cancelAll();
        log.warn("WebSocket session {} ({}) closed due to an exception",
                session.getId(), session.getRemoteAddress(), exception);
    }

    /**
     * Builds the periodic task that sends a keep-alive message while the session is open.
     * If the send fails, the task closes the session as unreliable.
     */
    private Runnable keepAliveTask(WebSocketSession session) {
        return () -> {
            try {
                if (session != null && session.isOpen()) {
                    send(session, ApolloMessage.keepAlive());
                }
            } catch (Exception exception) {
                log.debug("Failed to send keep-alive message; closing the WebSocket session", exception);
                try {
                    session.close(CloseStatus.SESSION_NOT_RELIABLE);
                } catch (Exception ignored) {
                    // Best effort: nothing more can be done for this session.
                }
            }
        };
    }

    /**
     * Sends a message on the session, serializing concurrent senders.
     *
     * <p>Messages originate from the handler thread, the keep-alive scheduler and subscription
     * publishers. The Spring {@link WebSocketSession#sendMessage} contract requires concurrent
     * sends to be synchronized.
     */
    private static void send(WebSocketSession session, TextMessage message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }
}