package org.javacord.core;

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.logging.log4j.CloseableThreadContext;
import org.apache.logging.log4j.Logger;
import org.javacord.api.AccountType;
import org.javacord.api.DiscordApi;
import org.javacord.api.internal.DiscordApiBuilderDelegate;
import org.javacord.api.listener.GloballyAttachableListener;
import org.javacord.api.util.auth.Authenticator;
import org.javacord.api.util.ratelimit.Ratelimiter;
import org.javacord.core.util.gateway.DiscordWebSocketAdapter;
import org.javacord.core.util.logging.LoggerUtil;
import org.javacord.core.util.logging.PrivacyProtectionLogger;
import org.javacord.core.util.rest.RestEndpoint;
import org.javacord.core.util.rest.RestMethod;
import org.javacord.core.util.rest.RestRequest;
import org.javacord.core.util.rest.RestRequestResult;

import java.net.Proxy;
import java.net.ProxySelector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

/**
 * The implementation of {@link DiscordApiBuilderDelegate}.
 */
public class DiscordApiBuilderDelegateImpl implements DiscordApiBuilderDelegate {

    /**
     * The logger of this class.
     */
    private static final Logger logger = LoggerUtil.getLogger(DiscordApiBuilderDelegateImpl.class);

    /**
     * A ratelimiter that is used for global ratelimits.
     */
    private volatile Ratelimiter globalRatelimiter;

    /**
     * The proxy selector which should be used to determine the proxies that should be used to connect to the Discord
     * REST API and websocket.
     */
    private volatile ProxySelector proxySelector;

    /**
     * The proxy which should be used to connect to the Discord REST API and websocket.
     */
    private volatile Proxy proxy;

    /**
     * The authenticator that should be used to authenticate against proxies that require it.
     */
    private volatile Authenticator proxyAuthenticator;

    /**
     * Whether all SSL certificates should be trusted when connecting to the Discord API and websocket.
     */
    private volatile boolean trustAllCertificates = false;

    /**
     * The token which is used to login. Must be present in order to login!
     */
    private volatile String token = null;

    /**
     * The account type of the account with the given token.
     */
    private volatile AccountType accountType = AccountType.BOT;

    /**
     * The current shard starting with <code>0</code>.
     */
    private final AtomicInteger currentShard = new AtomicInteger();

    /**
     * The total amount of shards.
     * If the total amount is <code>1</code>, sharding will be disabled.
     */
    private final AtomicInteger totalShards = new AtomicInteger(1);

    /**
     * A retry attempt counter.
     */
    private final AtomicInteger retryAttempt = new AtomicInteger();

    /**
     * Whether Javacord should wait for all servers to become available on startup or not.
     */
    private volatile boolean waitForServersOnStartup = true;

    /**
     * The globally attachable listeners to register for every created DiscordApi instance.
     */
    private final Map<Class<? extends GloballyAttachableListener>,
            List<GloballyAttachableListener>> listeners
            = new ConcurrentHashMap<>();

    /**
     * Suppliers for globally attachable listeners.
     */
    private final Map<Class<? extends GloballyAttachableListener>,
            List<Supplier<? extends GloballyAttachableListener>>> listenerSuppliers
            = new ConcurrentHashMap<>();

    /**
     * Functions for globally attachable listeners.
     */
    private final Map<Class<? extends GloballyAttachableListener>,
            List<Function<DiscordApi, ? extends GloballyAttachableListener>>> listenerFunctions
            = new ConcurrentHashMap<>();

    /**
     * Globally attachable listeners in need of subtype detection.
     */
    private final List<GloballyAttachableListener> unspecifiedListeners
            = new CopyOnWriteArrayList<>();

    /**
     * Globally attachable listener suppliers in need of subtype detection.
     */
    private final List<Supplier<GloballyAttachableListener>> unspecifiedListenerSuppliers
            = new CopyOnWriteArrayList<>();

    /**
     * Globally attachable listeners in need of subtype detection.
     */
    private final List<Function<DiscordApi,GloballyAttachableListener>> unspecifiedListenerFunctions
            = new CopyOnWriteArrayList<>();

    /**
     * Listener sources for pre-registration, compiled into a single map.
     */
    private volatile Map<Class<? extends GloballyAttachableListener>,
            List<Function<DiscordApi, GloballyAttachableListener>>> preparedListeners;

    /**
     * Unspecified listener sources for pre-registration, compiled into a single map.
     */
    private volatile List<Function<DiscordApi,GloballyAttachableListener>> preparedUnspecifiedListeners;


    @Override
    public CompletableFuture<DiscordApi> login() {
        prepareListeners();
        logger.debug("Creating shard {} of {}", currentShard.get() + 1, totalShards.get());
        CompletableFuture<DiscordApi> future = new CompletableFuture<>();
        if (token == null) {
            future.completeExceptionally(new IllegalArgumentException("You cannot login without a token!"));
            return future;
        }
        try (CloseableThreadContext.Instance closeableThreadContextInstance =
                     CloseableThreadContext.put("shard", Integer.toString(currentShard.get()))) {
            new DiscordApiImpl(accountType, token, currentShard.get(), totalShards.get(), waitForServersOnStartup,
                    globalRatelimiter, proxySelector, proxy, proxyAuthenticator, trustAllCertificates, future, null,
                    preparedListeners, preparedUnspecifiedListeners);
        }
        return future;
    }

    /**
     * Compile pre-registered listeners into proper collections for DiscordApi creation.
     */
    @SuppressWarnings("unchecked")
    private void prepareListeners() {
        if (preparedListeners != null && preparedUnspecifiedListeners != null) {
            // Already created, skip
            return;
        }
        preparedListeners = new ConcurrentHashMap<>();
        Stream<Class<? extends GloballyAttachableListener>> eventTypes = Stream.concat(
                listeners.keySet().stream(),
                Stream.concat(listenerSuppliers.keySet().stream(),
                        listenerFunctions.keySet().stream())
        ).distinct();
        eventTypes.forEach(type -> {
            ArrayList<Function<DiscordApi, GloballyAttachableListener>> typeListenerFunctions = new ArrayList<>();
            listeners.getOrDefault(type, Collections.emptyList()).forEach(
                    listener -> typeListenerFunctions.add(api -> listener)
            );
            listenerSuppliers.getOrDefault(type, Collections.emptyList()).forEach(
                    supplier -> typeListenerFunctions.add(api -> supplier.get())
            );
            listenerFunctions.getOrDefault(type, Collections.emptyList()).forEach(
                    function -> typeListenerFunctions.add((Function<DiscordApi, GloballyAttachableListener>) function)
            );
            preparedListeners.put(type, typeListenerFunctions);
        });
        // Unspecified Listeners
        preparedUnspecifiedListeners = new CopyOnWriteArrayList<>(unspecifiedListenerFunctions);
        unspecifiedListenerSuppliers.forEach(supplier -> preparedUnspecifiedListeners.add((api) -> supplier.get()));
        unspecifiedListeners.forEach(listener -> preparedUnspecifiedListeners.add((api) -> listener));
    }

    @Override
    public Collection<CompletableFuture<DiscordApi>> loginShards(int... shards) {
        Objects.requireNonNull(shards);
        if (shards.length == 0) {
            return Collections.emptyList();
        }
        if (Arrays.stream(shards).distinct().count() != shards.length) {
            throw new IllegalArgumentException("shards cannot be started multiple times!");
        }
        if (Arrays.stream(shards).max().orElseThrow(AssertionError::new) >= getTotalShards()) {
            throw new IllegalArgumentException("shard cannot be greater or equal than totalShards!");
        }
        if (Arrays.stream(shards).min().orElseThrow(AssertionError::new) < 0) {
            throw new IllegalArgumentException("shard cannot be less than 0!");
        }

        if (shards.length == getTotalShards()) {
            logger.info("Creating {} {}", getTotalShards(), (getTotalShards() == 1) ? "shard" : "shards");
        } else {
            logger.info("Creating {} out of {} shards ({})", shards.length, getTotalShards(), shards);
        }

        Collection<CompletableFuture<DiscordApi>> result = new ArrayList<>(shards.length);
        int currentShard = getCurrentShard();
        for (int shard : shards) {
            if (currentShard != 0) {
                CompletableFuture<DiscordApi> future = new CompletableFuture<>();
                future.completeExceptionally(new IllegalArgumentException(
                        "You cannot use loginShards or loginAllShards after setting the current shard!"));
                result.add(future);
                continue;
            }
            setCurrentShard(shard);
            result.add(login());
        }
        setCurrentShard(currentShard);
        return result;
    }

    @Override
    public void setGlobalRatelimiter(Ratelimiter ratelimiter) {
        globalRatelimiter = ratelimiter;
    }

    @Override
    public void setProxySelector(ProxySelector proxySelector) {
        this.proxySelector = proxySelector;
    }

    @Override
    public void setProxy(Proxy proxy) {
        this.proxy = proxy;
    }

    @Override
    public void setProxyAuthenticator(Authenticator authenticator) {
        proxyAuthenticator = authenticator;
    }

    @Override
    public void setTrustAllCertificates(boolean trustAllCertificates) {
        this.trustAllCertificates = trustAllCertificates;
    }

    @Override
    public void setToken(String token) {
        this.token = token;
        PrivacyProtectionLogger.addPrivateData(token);
    }

    @Override
    public Optional<String> getToken() {
        return Optional.ofNullable(token);
    }

    @Override
    public void setAccountType(AccountType type) {
        this.accountType = type;
    }

    @Override
    public AccountType getAccountType() {
        return accountType;
    }

    @Override
    public void setTotalShards(int totalShards) {
        if (currentShard.get() >= totalShards) {
            throw new IllegalArgumentException("currentShard cannot be greater or equal than totalShards!");
        }
        if (totalShards < 1) {
            throw new IllegalArgumentException("totalShards cannot be less than 1!");
        }
        this.totalShards.set(totalShards);
    }

    @Override
    public int getTotalShards() {
        return totalShards.get();
    }

    @Override
    public void setCurrentShard(int currentShard) {
        if (currentShard >= totalShards.get()) {
            throw new IllegalArgumentException("currentShard cannot be greater or equal than totalShards!");
        }
        if (currentShard < 0) {
            throw new IllegalArgumentException("currentShard cannot be less than 0!");
        }
        this.currentShard.set(currentShard);
    }

    @Override
    public int getCurrentShard() {
        return currentShard.get();
    }

    @Override
    public void setWaitForServersOnStartup(boolean waitForServersOnStartup) {
        this.waitForServersOnStartup = waitForServersOnStartup;
    }

    @Override
    public boolean isWaitingForServersOnStartup() {
        return waitForServersOnStartup;
    }

    @Override
    public CompletableFuture<Void> setRecommendedTotalShards() {
        CompletableFuture<Void> future = new CompletableFuture<>();
        if (token == null) {
            future.completeExceptionally(
                    new IllegalArgumentException("You cannot request the recommended total shards without a token!"));
        } else {
            retryAttempt.set(0);
            setRecommendedTotalShards(future);
        }
        return future;
    }

    private void setRecommendedTotalShards(CompletableFuture<Void> future) {
        DiscordApiImpl api = new DiscordApiImpl(
                token, globalRatelimiter, proxySelector, proxy, proxyAuthenticator, trustAllCertificates);
        RestRequest<JsonNode> botGatewayRequest = new RestRequest<>(api, RestMethod.GET, RestEndpoint.GATEWAY_BOT);
        botGatewayRequest
                .execute(RestRequestResult::getJsonBody)
                .thenAccept(resultJson -> {
                    DiscordWebSocketAdapter.setGateway(resultJson.get("url").asText());
                    setTotalShards(resultJson.get("shards").asInt());
                    retryAttempt.set(0);
                    future.complete(null);
                })
                .exceptionally(t -> {
                    int retryDelay = api.getReconnectDelay(retryAttempt.incrementAndGet());
                    logger.info("Retrying to get recommended total shards in {} seconds!", retryDelay);
                    api.getThreadPool().getScheduler().schedule(
                            () -> setRecommendedTotalShards(future), retryDelay, TimeUnit.SECONDS);
                    return null;
                })
                .whenComplete((nothing, throwable) -> api.disconnect());
    }

    @Override
    @SuppressWarnings("unchecked")
    public <T extends GloballyAttachableListener> void addListener(Class<T> listenerClass, T listener) {
        this.listeners.computeIfAbsent(listenerClass, clazz -> new CopyOnWriteArrayList<>());
        List<T> listeners = (List<T>) this.listeners.get(listenerClass);
        if (!listeners.contains(listener)) {
            listeners.add(listener);
        }
    }

    @Override
    public void addListener(GloballyAttachableListener listener) {
        if (!this.unspecifiedListeners.contains(listener)) {
            this.unspecifiedListeners.add(listener);
        }
    }

    @Override
    public <T extends GloballyAttachableListener> void addListener(
                                            Class<T> listenerClass, Supplier<T> listenerSupplier) {
        this.listenerSuppliers.computeIfAbsent(listenerClass, clazz -> new CopyOnWriteArrayList<>());
        List<Supplier<? extends GloballyAttachableListener>> listeners = this.listenerSuppliers.get(listenerClass);
        if (!listeners.contains(listenerSupplier)) {
            listeners.add(listenerSupplier);
        }
    }

    @Override
    public void addListener(Supplier<GloballyAttachableListener> listenerSupplier) {
        if (!this.unspecifiedListenerSuppliers.contains(listenerSupplier))  {
            this.unspecifiedListenerSuppliers.add(listenerSupplier);
        }
    }

    @Override
    public <T extends GloballyAttachableListener> void addListener(
                                                Class<T> listenerClass, Function<DiscordApi, T> listenerFunction) {
        this.listenerFunctions.computeIfAbsent(listenerClass, clazz -> new CopyOnWriteArrayList<>());
        List<Function<DiscordApi, ? extends GloballyAttachableListener>> functions =
                this.listenerFunctions.get(listenerClass);
        if (!functions.contains(listenerFunction)) {
            functions.add(listenerFunction);
        }
    }

    @Override
    public void addListener(Function<DiscordApi, GloballyAttachableListener> listenerFunction) {
        if (!this.unspecifiedListenerFunctions.contains(listenerFunction)) {
            this.unspecifiedListenerFunctions.add(listenerFunction);
        }
    }

    @Override
    public void removeListener(GloballyAttachableListener listener) {
        this.unspecifiedListeners.remove(listener);
    }

    @Override
    public <T extends GloballyAttachableListener> void removeListener(Class<T> listenerClass, T listener) {
        this.listeners.computeIfPresent(listenerClass, (clazz, listeners) -> {
            listeners.remove(listener);
            return listeners.isEmpty() ? null : listeners;
        });
    }

    @Override
    public void removeListenerSupplier(Supplier<GloballyAttachableListener> listenerSupplier) {
        this.unspecifiedListenerSuppliers.remove(listenerSupplier);
    }

    @Override
    public <T extends GloballyAttachableListener> void removeListenerSupplier(
                                                Class<T> listenerClass, Supplier<T> listenerSupplier) {
        this.listenerSuppliers.computeIfPresent(listenerClass, (clazz, suppliers) -> {
            suppliers.remove(listenerSupplier);
            return suppliers.isEmpty() ? null : suppliers;
        });
    }

    @Override
    public void removeListenerFunction(Function<DiscordApi, GloballyAttachableListener> listenerFunction) {
        this.unspecifiedListenerFunctions.remove(listenerFunction);
    }

    @Override
    public <T extends GloballyAttachableListener> void removeListenerFunction(
                                                Class<T> listenerClass, Function<DiscordApi, T> listenerFunction) {
        this.listenerFunctions.computeIfPresent(listenerClass, (clazz, functions) -> {
            functions.remove(listenerFunction);
            return functions.isEmpty() ? null : functions;
        });
    }
}