package com.github.rawsanj.handler;

import com.github.rawsanj.messaging.RedisChatMessagePublisher;
import com.github.rawsanj.model.ChatMessage;
import com.github.rawsanj.util.ObjectStringConverter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.support.atomic.RedisAtomicLong;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import reactor.core.publisher.DirectProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

@Slf4j
public class ChatWebSocketHandler implements WebSocketHandler {

	private final DirectProcessor<ChatMessage> messageDirectProcessor;
	private final FluxSink<ChatMessage> chatMessageFluxSink;
	private final RedisChatMessagePublisher redisChatMessagePublisher;
	private final RedisAtomicLong activeUserCounter;

	public ChatWebSocketHandler(DirectProcessor<ChatMessage> messageDirectProcessor, RedisChatMessagePublisher redisChatMessagePublisher, RedisAtomicLong activeUserCounter) {
		this.messageDirectProcessor = messageDirectProcessor;
		this.chatMessageFluxSink = messageDirectProcessor.sink();
		this.redisChatMessagePublisher = redisChatMessagePublisher;
		this.activeUserCounter = activeUserCounter;
	}

	@Override
	public Mono<Void> handle(WebSocketSession webSocketSession) {
		Flux<WebSocketMessage> sendMessageFlux = messageDirectProcessor.flatMap(ObjectStringConverter::objectToString)
			.map(webSocketSession::textMessage)
			.doOnError(throwable -> log.info("Error Occurred while sending message to WebSocket.", throwable));
		Mono<Void> outputMessage = webSocketSession.send(sendMessageFlux);

		Mono<Void> inputMessage = webSocketSession.receive()
			.flatMap(webSocketMessage -> redisChatMessagePublisher.publishChatMessage(webSocketMessage.getPayloadAsText()))
			.doOnSubscribe(subscription -> {
				long activeUserCount = activeUserCounter.incrementAndGet();
				log.debug("User '{}' Connected. Total Active Users: {}", webSocketSession.getId(), activeUserCount);
				chatMessageFluxSink.next(new ChatMessage(0, "CONNECTED", "CONNECTED", activeUserCount));
			})
			.doOnError(throwable -> log.info("Error Occurred while sending message to Redis.", throwable))
			.doFinally(signalType -> {
				long activeUserCount = activeUserCounter.decrementAndGet();
				log.debug("User '{}' Disconnected. Total Active Users: {}", webSocketSession.getId(), activeUserCount);
				chatMessageFluxSink.next(new ChatMessage(0, "DISCONNECTED", "DISCONNECTED", activeUserCount));
			})
			.then();

		return Mono.zip(inputMessage, outputMessage).then();
	}

	public Mono<Void> sendMessage(ChatMessage chatMessage) {
		return Mono.fromSupplier(() -> chatMessageFluxSink.next(chatMessage)).then();
	}

}