/*
 * Copyright 2017 The Hekate Project
 *
 * The Hekate Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.hekate.cluster.seed.multicast;

import io.hekate.cluster.ClusterServiceFactory;
import io.hekate.cluster.seed.SeedNodeProvider;
import io.hekate.core.HekateException;
import io.hekate.core.internal.util.AddressUtils;
import io.hekate.core.internal.util.ConfigCheck;
import io.hekate.core.internal.util.HekateThreadFactory;
import io.hekate.core.internal.util.Utils;
import io.hekate.network.netty.NettyUtils;
import io.hekate.util.async.Waiting;
import io.hekate.util.format.ToString;
import io.hekate.util.format.ToStringIgnore;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.util.concurrent.ScheduledFuture;
import java.io.IOException;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static java.util.stream.Collectors.toList;

/**
 * Multicast-based seed node provider.
 *
 * <p>
 * This provider uses <a href="https://en.wikipedia.org/wiki/IP_multicast" target="_blank">IP multicasting</a> for seed nodes discovery.
 * When this provider starts it periodically sends multicast messages and awaits for responses from other nodes up to the configurable
 * timeout value.
 * </p>
 *
 * <p>
 * Configuration of this provider is represented by the {@link MulticastSeedNodeProviderConfig} class. Please see its documentation for
 * information about available configuration options.
 * </p>
 *
 * @see ClusterServiceFactory#setSeedNodeProvider(SeedNodeProvider)
 * @see MulticastSeedNodeProviderConfig
 */
public class MulticastSeedNodeProvider implements SeedNodeProvider {
    private enum MessageTYpe {
        DISCOVERY,

        SEED_NODE_INFO
    }

    private static class SeedNode {
        private final InetSocketAddress address;

        private final String cluster;

        public SeedNode(InetSocketAddress address, String cluster) {
            this.address = address;
            this.cluster = cluster;
        }

        public InetSocketAddress address() {
            return address;
        }

        public String cluster() {
            return cluster;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }

            if (!(o instanceof SeedNode)) {
                return false;
            }

            SeedNode seedNode = (SeedNode)o;

            if (address != null ? !address.equals(seedNode.address) : seedNode.address != null) {
                return false;
            }

            return !(cluster != null ? !cluster.equals(seedNode.cluster) : seedNode.cluster != null);
        }

        @Override
        public int hashCode() {
            int result = address != null ? address.hashCode() : 0;

            result = 31 * result + (cluster != null ? cluster.hashCode() : 0);

            return result;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "[cluster=" + cluster + ", address=" + address + ']';
        }
    }

    private static final Logger log = LoggerFactory.getLogger(MulticastSeedNodeProvider.class);

    private static final boolean DEBUG = log.isDebugEnabled();

    @ToStringIgnore
    private final Object mux = new Object();

    private final InetSocketAddress group;

    private final InternetProtocolFamily ipVer;

    private final int ttl;

    private final long interval;

    private final long waitTime;

    private final boolean loopBackDisabled;

    @ToStringIgnore
    private SeedNode localNode;

    @ToStringIgnore
    private Set<SeedNode> seedNodes;

    @ToStringIgnore
    private NioEventLoopGroup eventLoop;

    @ToStringIgnore
    private DatagramChannel listener;

    @ToStringIgnore
    private DatagramChannel sender;

    @ToStringIgnore
    private ScheduledFuture<?> discoveryFuture;

    /**
     * Constructs new instance with all configuration options set to their default values (see {@link MulticastSeedNodeProviderConfig}).
     *
     * @throws UnknownHostException If failed to resolve multicast group address.
     */
    public MulticastSeedNodeProvider() throws UnknownHostException {
        this(new MulticastSeedNodeProviderConfig());
    }

    /**
     * Constructs new instance.
     *
     * @param cfg Configuration.
     *
     * @throws UnknownHostException If failed to resolve multicast group address.
     */
    public MulticastSeedNodeProvider(MulticastSeedNodeProviderConfig cfg) throws UnknownHostException {
        ConfigCheck check = ConfigCheck.get(getClass());

        check.notNull(cfg, "configuration");
        check.positive(cfg.getPort(), "port");
        check.nonNegative(cfg.getTtl(), "TTL");
        check.notEmpty(cfg.getGroup(), "multicast group");
        check.positive(cfg.getInterval(), "discovery interval");
        check.positive(cfg.getWaitTime(), "wait time");
        check.that(cfg.getInterval() < cfg.getWaitTime(), "discovery interval must be greater than wait time "
            + "[discovery-interval=" + cfg.getInterval() + ", wait-time=" + cfg.getWaitTime() + ']');

        InetAddress groupAddress = InetAddress.getByName(cfg.getGroup());

        check.isTrue(groupAddress.isMulticastAddress(), "address is not a multicast address [address=" + groupAddress + ']');

        group = new InetSocketAddress(groupAddress, cfg.getPort());
        ttl = cfg.getTtl();
        interval = cfg.getInterval();
        waitTime = cfg.getWaitTime();
        loopBackDisabled = cfg.isLoopBackDisabled();

        ipVer = group.getAddress() instanceof Inet6Address ? InternetProtocolFamily.IPv6 : InternetProtocolFamily.IPv4;
    }

    @Override
    public List<InetSocketAddress> findSeedNodes(String cluster) throws HekateException {
        synchronized (mux) {
            if (isRegistered()) {
                return seedNodes.stream().filter(s -> s.cluster().equals(cluster)).map(SeedNode::address).collect(toList());
            }
        }

        return Collections.emptyList();
    }

    @Override
    public void startDiscovery(String cluster, InetSocketAddress address) throws HekateException {
        log.info("Starting seed nodes discovery [cluster={}, {}]", cluster, ToString.formatProperties(this));

        SeedNode thisNode = new SeedNode(address, cluster);

        NetworkInterface nif = selectMulticastInterface(address);

        try {
            synchronized (mux) {
                if (isRegistered()) {
                    throw new IllegalStateException("Multicast seed node provider is already registered with another address "
                        + "[existing=" + localNode + ']');
                }

                ByteBuf discoveryMsg = prepareDiscovery(thisNode);

                ByteBuf seedNodeInfoBytes = prepareSeedNodeInfo(thisNode);

                localNode = thisNode;

                seedNodes = new HashSet<>();

                eventLoop = new NioEventLoopGroup(1, new HekateThreadFactory("SeedNodeMulticast"));

                // Prepare common bootstrap options.
                Bootstrap bootstrap = new Bootstrap();

                bootstrap.option(ChannelOption.SO_REUSEADDR, true);
                bootstrap.option(ChannelOption.IP_MULTICAST_TTL, ttl);
                bootstrap.option(ChannelOption.IP_MULTICAST_IF, nif);

                if (loopBackDisabled) {
                    bootstrap.option(ChannelOption.IP_MULTICAST_LOOP_DISABLED, true);

                    if (DEBUG) {
                        log.debug("Setting {} option to true", ChannelOption.IP_MULTICAST_LOOP_DISABLED);
                    }
                }

                bootstrap.group(eventLoop);
                bootstrap.channelFactory(() -> new NioDatagramChannel(ipVer));

                // Create a sender channel (not joined to a multicast group).
                bootstrap.localAddress(0);
                bootstrap.handler(createSenderHandler(thisNode));

                ChannelFuture senderBind = bootstrap.bind();

                DatagramChannel localSender = (DatagramChannel)senderBind.channel();

                sender = localSender;

                senderBind.get();

                // Create a listener channel and join to a multicast group.
                bootstrap.localAddress(group.getPort());

                bootstrap.handler(createListenerHandler(thisNode, seedNodeInfoBytes));

                ChannelFuture listenerBind = bootstrap.bind();

                listener = (DatagramChannel)listenerBind.channel();

                listenerBind.get();

                log.info("Joining to a multicast group "
                    + "[address={}, port={}, interface={}, ttl={}]", AddressUtils.host(group), group.getPort(), nif.getName(), ttl);

                listener.joinGroup(group, nif).get();

                // Create a periodic task for discovery messages sending.
                discoveryFuture = eventLoop.scheduleWithFixedDelay(() -> {
                    if (DEBUG) {
                        log.debug("Sending discovery message [from={}]", thisNode);
                    }

                    DatagramPacket discovery = new DatagramPacket(discoveryMsg.copy(), group);

                    localSender.writeAndFlush(discovery);
                }, 0, interval, TimeUnit.MILLISECONDS);
            }
        } catch (InterruptedException e) {
            cleanup();

            Thread.currentThread().interrupt();

            throw new HekateException("Thread was interrupted while awaiting for multicast discovery [node=" + thisNode + ']', e);
        } catch (ExecutionException e) {
            cleanup();

            throw new HekateException("Failed to start a multicast seed nodes discovery [node=" + thisNode + ']', e.getCause());
        }

        log.info("Will wait for seed nodes [timeout={}(ms)]", waitTime);

        try {
            Thread.sleep(waitTime);
        } catch (InterruptedException e) {
            log.warn("Thread was interrupted while awaiting for seed nodes discovery.");

            Thread.currentThread().interrupt();
        }

        log.info("Done waiting for seed nodes.");
    }

    @Override
    public void suspendDiscovery() {
        synchronized (mux) {
            if (discoveryFuture != null) {
                if (DEBUG) {
                    log.debug("Canceling discovery task.");
                }

                discoveryFuture.cancel(false);

                discoveryFuture = null;
            }

            if (sender != null && sender.isOpen()) {
                if (DEBUG) {
                    log.debug("Closing multicast sender channel [channel={}]", sender);
                }

                sender.close();

                sender = null;
            }
        }
    }

    @Override
    public void stopDiscovery(String cluster, InetSocketAddress address) throws HekateException {
        log.info("Stopping seed nodes discovery [cluster={}, address={}]", cluster, address);

        cleanup();
    }

    @Override
    public long cleanupInterval() {
        return 0;
    }

    @Override
    public void registerRemote(String cluster, InetSocketAddress node) throws HekateException {
        // No-op.
    }

    @Override
    public void unregisterRemote(String cluster, InetSocketAddress node) throws HekateException {
        // No-op.
    }

    /**
     * Returns the multicast group (see {@link MulticastSeedNodeProviderConfig#setGroup(String)}).
     *
     * @return Multicast group.
     */
    public InetSocketAddress group() {
        return group;
    }

    /**
     * Returns the multicast TTL (see {@link MulticastSeedNodeProviderConfig#setTtl(int)}).
     *
     * @return Multicast group.
     */
    public int ttl() {
        return ttl;
    }

    /**
     * Returns the multicast interval (see {@link MulticastSeedNodeProviderConfig#setInterval(long)}).
     *
     * @return Multicast interval.
     */
    public long interval() {
        return interval;
    }

    /**
     * Returns the time to await for responses from remote nodes (see {@link MulticastSeedNodeProviderConfig#setWaitTime(long)}).
     *
     * @return Time to await for responses from remote nodes.
     */
    public long waitTime() {
        return waitTime;
    }

    private void cleanup() {
        try {
            Waiting eventLoopStop = null;

            synchronized (mux) {
                suspendDiscovery();

                if (listener != null && listener.isOpen()) {
                    if (DEBUG) {
                        log.debug("Closing multicast listener channel [channel={}]", listener);
                    }

                    listener.close();
                }

                if (eventLoop != null) {
                    if (DEBUG) {
                        log.debug("Terminating multicast thread pool...");
                    }

                    eventLoopStop = NettyUtils.shutdown(eventLoop);

                    if (DEBUG) {
                        log.debug("Terminated multicast thread pool.");
                    }
                }
            }

            // Await for thread pool termination out of synchronized context in order to prevent deadlocks.
            if (eventLoopStop != null) {
                eventLoopStop.awaitUninterruptedly();
            }
        } finally {
            synchronized (mux) {
                localNode = null;
                seedNodes = null;
                discoveryFuture = null;
                listener = null;
                eventLoop = null;
            }
        }
    }

    private SimpleChannelInboundHandler<DatagramPacket> createSenderHandler(SeedNode thisNode) {
        return new SimpleChannelInboundHandler<DatagramPacket>() {
            @Override
            public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception {
                ByteBuf buf = msg.content();

                if (buf.readableBytes() > 4 && buf.readInt() == Utils.MAGIC_BYTES) {
                    MessageTYpe msgType = MessageTYpe.values()[buf.readByte()];

                    if (msgType == MessageTYpe.SEED_NODE_INFO) {
                        String cluster = decodeUtf(buf);

                        InetSocketAddress address = decodeAddress(buf);

                        SeedNode otherNode = new SeedNode(address, cluster);

                        if (!thisNode.equals(otherNode)) {
                            if (DEBUG) {
                                log.debug("Received seed node info message [node={}]", otherNode);
                            }

                            boolean discovered = false;

                            synchronized (mux) {
                                if (isRegistered()) {
                                    if (!seedNodes.contains(otherNode)) {
                                        discovered = true;

                                        seedNodes.add(otherNode);
                                    }
                                }
                            }

                            if (discovered) {
                                log.info("Seed node discovered [address={}]", otherNode.address());
                            }
                        }
                    }
                }
            }
        };
    }

    private SimpleChannelInboundHandler<DatagramPacket> createListenerHandler(SeedNode thisNode, ByteBuf seedNodeInfo) {
        return new SimpleChannelInboundHandler<DatagramPacket>() {
            @Override
            public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception {
                ByteBuf buf = msg.content();

                if (buf.readableBytes() > 4 && buf.readInt() == Utils.MAGIC_BYTES) {
                    MessageTYpe msgType = MessageTYpe.values()[buf.readByte()];

                    if (msgType == MessageTYpe.DISCOVERY) {
                        String cluster = decodeUtf(buf);
                        InetSocketAddress address = decodeAddress(buf);

                        if (thisNode.cluster().equals(cluster) && !address.equals(thisNode.address())) {
                            onDiscoveryMessage(address);

                            DatagramPacket response = new DatagramPacket(seedNodeInfo.copy(), msg.sender());

                            ctx.writeAndFlush(response);
                        }
                    }
                }
            }
        };
    }

    // Package level for testing purposes.
    void onDiscoveryMessage(InetSocketAddress address) {
        if (DEBUG) {
            log.debug("Received discovery message [from={}]", address);
        }
    }

    private boolean isRegistered() {
        assert Thread.holdsLock(mux) : "Thread must hold lock on mutex.";

        return localNode != null;
    }

    private ByteBuf prepareSeedNodeInfo(SeedNode node) {
        ByteBuf out = Unpooled.buffer();

        out.writeInt(Utils.MAGIC_BYTES);
        out.writeByte(MessageTYpe.SEED_NODE_INFO.ordinal());

        encodeUtf(node.cluster(), out);
        encodeAddress(node.address(), out);

        return out;
    }

    private ByteBuf prepareDiscovery(SeedNode node) {
        ByteBuf out = Unpooled.buffer();

        out.writeInt(Utils.MAGIC_BYTES);
        out.writeByte(MessageTYpe.DISCOVERY.ordinal());

        encodeUtf(node.cluster(), out);
        encodeAddress(node.address(), out);

        return out;
    }

    private void encodeUtf(String str, ByteBuf out) {
        byte[] clusterIdBytes = str.getBytes(StandardCharsets.UTF_8);

        out.writeInt(clusterIdBytes.length);
        out.writeBytes(clusterIdBytes);
    }

    private String decodeUtf(ByteBuf buf) {
        int len = buf.readInt();

        byte[] strBytes = new byte[len];

        buf.readBytes(strBytes);

        return new String(strBytes, StandardCharsets.UTF_8);
    }

    private void encodeAddress(InetSocketAddress socketAddress, ByteBuf out) {
        byte[] addrBytes = socketAddress.getAddress().getAddress();

        out.writeByte(addrBytes.length);
        out.writeBytes(addrBytes);
        out.writeInt(socketAddress.getPort());
    }

    private InetSocketAddress decodeAddress(ByteBuf buf) throws UnknownHostException {
        byte[] addrBytes = new byte[buf.readByte()];

        buf.readBytes(addrBytes);

        int port = buf.readInt();

        return new InetSocketAddress(InetAddress.getByAddress(addrBytes), port);
    }

    // Package level for testing purposes.
    NetworkInterface selectMulticastInterface(InetSocketAddress node) throws HekateException {
        InetAddress address = node.getAddress();

        try {
            if (DEBUG) {
                log.debug("Resolving a network interface [address={}]", address);
            }

            NetworkInterface nif = NetworkInterface.getByInetAddress(address);

            if (nif == null && address.isLoopbackAddress()) {
                if (DEBUG) {
                    log.debug("Failed to resolve a network interface for a loopback address. Will try to find a loopback interface.");
                }

                nif = findLoopbackInterface(address);
            }

            if (nif == null) {
                throw new HekateException("Failed to resolve a network interface by address [address=" + address + ']');
            }

            if (!nif.supportsMulticast()) {
                throw new HekateException("Network interface doesn't support multicasting [name=" + nif.getName()
                    + ", interface-address=" + address + ']');
            }

            return nif;
        } catch (IOException e) {
            throw new HekateException("Failed to resolve multicast network interface [interface-address=" + address + ']', e);
        }
    }

    private NetworkInterface findLoopbackInterface(InetAddress address) throws SocketException, HekateException {
        NetworkInterface lo = null;

        for (NetworkInterface nif : AddressUtils.activeNetworks()) {
            if (nif.isUp() && nif.isLoopback()) {
                for (InetAddress nifAddress : Collections.list(nif.getInetAddresses())) {
                    if (!nifAddress.isLinkLocalAddress() && nifAddress.isLoopbackAddress()) {
                        if (lo != null) {
                            throw new HekateException("Failed to resolve a loopback network interface. "
                                + "Multiple loopback interfaces were detected [address=" + address + ", interface1=" + lo.getName()
                                + ", interface2=" + nif.getName() + ']');
                        }

                        lo = nif;

                        break;
                    }
                }
            }
        }

        return lo;
    }

    @Override
    public String toString() {
        return ToString.format(this);
    }
}