package tlschannel.async;

import java.io.IOException;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.ReadPendingException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ShutdownChannelGroupException;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritePendingException;
import java.util.Iterator;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tlschannel.NeedsReadException;
import tlschannel.NeedsTaskException;
import tlschannel.NeedsWriteException;
import tlschannel.TlsChannel;
import tlschannel.impl.ByteBufferSet;
import tlschannel.util.Util;

/**
 * This class encapsulates the infrastructure for running {@link AsynchronousTlsChannel}s. Each
 * instance owns a selector thread, a thread pool (used to run I/O work and completion handlers)
 * and a single-thread scheduler for operation timeouts; together they make it possible to run a
 * group of asynchronous channels.
 */
public class AsynchronousTlsChannelGroup {

  private static final Logger logger = LoggerFactory.getLogger(AsynchronousTlsChannelGroup.class);

  /**
   * The main executor of the group has a bounded queue, whose size is this multiple of the number
   * of executor threads.
   */
  private static final int queueLengthMultiplier = 32;

  private static final AtomicInteger globalGroupCount = new AtomicInteger();

  /** Per-channel state of a socket registered in this group. */
  class RegisteredSocket {

    final TlsChannel tlsChannel;
    final SocketChannel socketChannel;

    /**
     * Used to wait until the channel is effectively in the selector (which happens asynchronously
     * to the initial registration). The latch is also released, with {@link #key} left {@code
     * null}, when the channel could not be registered (channel already closed, or group
     * terminated first).
     */
    final CountDownLatch registered = new CountDownLatch(1);

    /**
     * Selection key, written by the selector thread. Volatile because {@link #close()} may read it
     * from any thread.
     */
    volatile SelectionKey key;

    /** Protects {@link #readOperation} reference and instance. */
    final Lock readLock = new ReentrantLock();

    /** Protects {@link #writeOperation} reference and instance. */
    final Lock writeLock = new ReentrantLock();

    /** Current read operation, if not null. */
    ReadOperation readOperation;

    /** Current write operation, if not null. */
    WriteOperation writeOperation;

    /** Bitwise union of pending operations to be registered in the selector. */
    final AtomicInteger pendingOps = new AtomicInteger();

    RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel)
        throws ClosedChannelException {
      this.tlsChannel = tlsChannel;
      this.socketChannel = socketChannel;
    }

    /**
     * Cancels any pending read and write operation (failing them with {@link
     * CancellationException}), cancels the selection key, and decrements the registration count.
     * The bookkeeping is done even if one of the failure handlers throws, so that a misbehaving
     * handler cannot leak a registration (which would block a graceful shutdown forever).
     */
    public void close() {
      try {
        doCancelRead(this, null);
      } finally {
        try {
          doCancelWrite(this, null);
        } finally {
          SelectionKey currentKey = key;
          if (currentKey != null) {
            currentKey.cancel();
          }
          currentRegistrations.getAndDecrement();
          /*
           * Actual de-registration from the selector will happen asynchronously.
           */
          selector.wakeup();
        }
      }
    }
  }

  /** State shared by pending read and write operations. */
  private abstract static class Operation {
    final ByteBufferSet bufferSet;
    final LongConsumer onSuccess;
    final Consumer<Throwable> onFailure;

    /** Scheduled timeout task, or {@code null} if the operation has no timeout. */
    Future<?> timeoutFuture;

    Operation(ByteBufferSet bufferSet, LongConsumer onSuccess, Consumer<Throwable> onFailure) {
      this.bufferSet = bufferSet;
      this.onSuccess = onSuccess;
      this.onFailure = onFailure;
    }
  }

  static final class ReadOperation extends Operation {
    ReadOperation(ByteBufferSet bufferSet, LongConsumer onSuccess, Consumer<Throwable> onFailure) {
      super(bufferSet, onSuccess, onFailure);
    }
  }

  static final class WriteOperation extends Operation {

    /**
     * Because a write operation can flag a block (needs read/write) even after the source buffer
     * was read from, we need to accumulate consumed bytes.
     */
    long consumesBytes = 0;

    WriteOperation(ByteBufferSet bufferSet, LongConsumer onSuccess, Consumer<Throwable> onFailure) {
      super(bufferSet, onSuccess, onFailure);
    }
  }

  private final int id = globalGroupCount.getAndIncrement();

  /**
   * To be sparing with warnings, this flag ensures that the warning about needed tasks is only
   * logged once.
   */
  private final AtomicBoolean loggedTaskWarning = new AtomicBoolean();

  private final Selector selector;

  final ExecutorService executor;

  private final ScheduledThreadPoolExecutor timeoutExecutor =
      new ScheduledThreadPoolExecutor(
          1,
          runnable ->
              new Thread(runnable, String.format("async-channel-group-%d-timeout-thread", id)));

  private final Thread selectorThread =
      new Thread(this::loop, String.format("async-channel-group-%d-selector", id));

  private final ConcurrentLinkedQueue<RegisteredSocket> pendingRegistrations =
      new ConcurrentLinkedQueue<>();

  private enum Shutdown {
    No,
    Wait,
    Immediate
  }

  private volatile Shutdown shutdown = Shutdown.No;

  private final LongAdder selectionCount = new LongAdder();

  private final LongAdder startedReads = new LongAdder();
  private final LongAdder startedWrites = new LongAdder();
  private final LongAdder successfulReads = new LongAdder();
  private final LongAdder successfulWrites = new LongAdder();
  private final LongAdder failedReads = new LongAdder();
  private final LongAdder failedWrites = new LongAdder();
  private final LongAdder cancelledReads = new LongAdder();
  private final LongAdder cancelledWrites = new LongAdder();

  // used for synchronization
  private final AtomicInteger currentRegistrations = new AtomicInteger();

  private final LongAdder currentReads = new LongAdder();
  private final LongAdder currentWrites = new LongAdder();

  /**
   * Creates an instance of this class.
   *
   * @param nThreads number of threads in the executor used to assist the selector loop and run
   *     completion handlers.
   * @throws IllegalArgumentException if {@code nThreads} is not positive
   */
  public AsynchronousTlsChannelGroup(int nThreads) {
    if (nThreads <= 0) {
      // validate before opening the selector, so that a bad argument does not leak it
      throw new IllegalArgumentException("nThreads must be positive: " + nThreads);
    }
    try {
      selector = Selector.open();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    timeoutExecutor.setRemoveOnCancelPolicy(true);
    this.executor =
        new ThreadPoolExecutor(
            nThreads,
            nThreads,
            0,
            TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<>(nThreads * queueLengthMultiplier),
            runnable ->
                new Thread(runnable, String.format("async-channel-group-%d-handler-executor", id)),
            new ThreadPoolExecutor.CallerRunsPolicy());
    selectorThread.start();
  }

  /** Creates an instance of this class, using as many threads as available processors. */
  public AsynchronousTlsChannelGroup() {
    this(Runtime.getRuntime().availableProcessors());
  }

  /**
   * Queues a socket for registration in the selector; the registration itself is done
   * asynchronously by the selector thread.
   *
   * @throws ShutdownChannelGroupException if the group has been shut down
   */
  RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
      throws ClosedChannelException {
    if (shutdown != Shutdown.No) {
      throw new ShutdownChannelGroupException();
    }
    RegisteredSocket socket = new RegisteredSocket(reader, socketChannel);
    currentRegistrations.getAndIncrement();
    pendingRegistrations.add(socket);
    if (shutdown != Shutdown.No && pendingRegistrations.remove(socket)) {
      // A shutdown raced with this registration and the selector loop did not take the socket;
      // back out, so that nobody waits for a registration that may never happen. Wake up the
      // selector so that a graceful shutdown re-evaluates the registration count.
      currentRegistrations.getAndDecrement();
      selector.wakeup();
      throw new ShutdownChannelGroupException();
    }
    selector.wakeup();
    return socket;
  }

  /**
   * Cancels a pending read operation.
   *
   * @param op the operation to cancel, or {@code null} to cancel whatever read is pending; only in
   *     the latter case is the operation's failure handler invoked (with {@link
   *     CancellationException})
   * @return whether an operation was cancelled
   */
  boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
    socket.readLock.lock();
    try {
      ReadOperation current = socket.readOperation;
      // a null op means cancel any operation
      boolean cancel = (op == null) ? current != null : current == op;
      if (!cancel) {
        return false;
      }
      // update state before running the handler, so that a throwing handler leaves it consistent
      socket.readOperation = null;
      cancelTimeout(current);
      cancelledReads.increment();
      currentReads.decrement();
      if (op == null) {
        current.onFailure.accept(new CancellationException());
      }
      return true;
    } finally {
      socket.readLock.unlock();
    }
  }

  /**
   * Cancels a pending write operation.
   *
   * @param op the operation to cancel, or {@code null} to cancel whatever write is pending; only
   *     in the latter case is the operation's failure handler invoked (with {@link
   *     CancellationException})
   * @return whether an operation was cancelled
   */
  boolean doCancelWrite(RegisteredSocket socket, WriteOperation op) {
    socket.writeLock.lock();
    try {
      WriteOperation current = socket.writeOperation;
      // a null op means cancel any operation
      boolean cancel = (op == null) ? current != null : current == op;
      if (!cancel) {
        return false;
      }
      // update state before running the handler, so that a throwing handler leaves it consistent
      socket.writeOperation = null;
      cancelTimeout(current);
      cancelledWrites.increment();
      currentWrites.decrement();
      if (op == null) {
        current.onFailure.accept(new CancellationException());
      }
      return true;
    } finally {
      socket.writeLock.unlock();
    }
  }

  /**
   * Starts an asynchronous read.
   *
   * @param timeout timeout of the operation, or 0 for no timeout
   * @throws ReadPendingException if a read is already pending on this socket
   * @throws ShutdownChannelGroupException if the group is terminated
   */
  ReadOperation startRead(
      RegisteredSocket socket,
      ByteBufferSet buffer,
      long timeout,
      TimeUnit unit,
      LongConsumer onSuccess,
      Consumer<Throwable> onFailure)
      throws ReadPendingException {
    checkTerminated();
    Util.assertTrue(buffer.hasRemaining());
    waitForSocketRegistration(socket);
    ReadOperation op;
    socket.readLock.lock();
    try {
      if (socket.readOperation != null) {
        throw new ReadPendingException();
      }
      op = new ReadOperation(buffer, onSuccess, onFailure);
      /*
       * We do not try to outsmart the TLS state machine and register for both IO operations for
       * each new socket operation.
       */
      socket.pendingOps.set(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
      if (timeout != 0) {
        op.timeoutFuture =
            timeoutExecutor.schedule(
                () -> {
                  boolean success = doCancelRead(socket, op);
                  if (success) {
                    op.onFailure.accept(new InterruptedByTimeoutException());
                  }
                },
                timeout,
                unit);
      }
      socket.readOperation = op;
      // Count under the lock: completion, cancellation and timeout decrement under the same lock,
      // so they can never be observed before this increment.
      startedReads.increment();
      currentReads.increment();
    } finally {
      socket.readLock.unlock();
    }
    selector.wakeup();
    return op;
  }

  /**
   * Starts an asynchronous write.
   *
   * @param timeout timeout of the operation, or 0 for no timeout
   * @throws WritePendingException if a write is already pending on this socket
   * @throws ShutdownChannelGroupException if the group is terminated
   */
  WriteOperation startWrite(
      RegisteredSocket socket,
      ByteBufferSet buffer,
      long timeout,
      TimeUnit unit,
      LongConsumer onSuccess,
      Consumer<Throwable> onFailure)
      throws WritePendingException {
    checkTerminated();
    Util.assertTrue(buffer.hasRemaining());
    waitForSocketRegistration(socket);
    WriteOperation op;
    socket.writeLock.lock();
    try {
      if (socket.writeOperation != null) {
        throw new WritePendingException();
      }
      op = new WriteOperation(buffer, onSuccess, onFailure);
      /*
       * We do not try to outsmart the TLS state machine and register for both IO operations for
       * each new socket operation.
       */
      socket.pendingOps.set(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
      if (timeout != 0) {
        op.timeoutFuture =
            timeoutExecutor.schedule(
                () -> {
                  boolean success = doCancelWrite(socket, op);
                  if (success) {
                    op.onFailure.accept(new InterruptedByTimeoutException());
                  }
                },
                timeout,
                unit);
      }
      socket.writeOperation = op;
      // Count under the lock: completion, cancellation and timeout decrement under the same lock,
      // so they can never be observed before this increment.
      startedWrites.increment();
      currentWrites.increment();
    } finally {
      socket.writeLock.unlock();
    }
    selector.wakeup();
    return op;
  }

  private void checkTerminated() {
    if (isTerminated()) {
      throw new ShutdownChannelGroupException();
    }
  }

  /**
   * Blocks until the selector thread has processed the registration of the socket.
   *
   * @throws ShutdownChannelGroupException if the group terminated before registering the socket
   */
  private void waitForSocketRegistration(RegisteredSocket socket) {
    try {
      socket.registered.await();
    } catch (InterruptedException e) {
      // preserve the interrupt status for the caller
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }
    if (socket.key == null && isShutdown()) {
      throw new ShutdownChannelGroupException();
    }
  }

  /** Body of the selector thread. */
  private void loop() {
    try {
      while (shutdown == Shutdown.No
          || shutdown == Shutdown.Wait && currentRegistrations.intValue() > 0) {
        int c = selector.select(); // block
        selectionCount.increment();
        // avoid unnecessary creation of iterator object
        if (c > 0) {
          Iterator<SelectionKey> it = selector.selectedKeys().iterator();
          while (it.hasNext()) {
            SelectionKey key = it.next();
            it.remove();
            try {
              key.interestOps(0);
            } catch (CancelledKeyException e) {
              // can happen when channels are closed with pending operations
              continue;
            }
            RegisteredSocket socket = (RegisteredSocket) key.attachment();
            processRead(socket);
            processWrite(socket);
          }
        }
        registerPendingSockets();
        processPendingInterests();
      }
    } catch (Throwable e) {
      logger.error("error in selector loop", e);
      // The group cannot work without its selector loop: mark it as shut down, so that new
      // registrations are rejected and registered sockets are closed below (failing their pending
      // operations) instead of hanging forever.
      shutdown = Shutdown.Immediate;
    } finally {
      executor.shutdown();
      // use shutdownNow to stop delayed tasks
      timeoutExecutor.shutdownNow();
      if (shutdown == Shutdown.Immediate) {
        for (SelectionKey key : selector.keys()) {
          closeQuietly((RegisteredSocket) key.attachment());
        }
      }
      releasePendingRegistrations();
      try {
        selector.close();
      } catch (IOException e) {
        logger.warn("error closing selector: {}", e.getMessage());
      }
    }
  }

  /** Closes a socket during group termination; a throwing handler must not stop the others. */
  private void closeQuietly(RegisteredSocket socket) {
    try {
      socket.close();
    } catch (RuntimeException e) {
      logger.warn("error closing socket during group termination", e);
    }
  }

  /**
   * Releases threads waiting for registrations that the (terminating) selector loop will never
   * perform. Their keys stay {@code null}.
   */
  private void releasePendingRegistrations() {
    RegisteredSocket socket;
    while ((socket = pendingRegistrations.poll()) != null) {
      socket.registered.countDown();
    }
  }

  /** Adds the interest ops accumulated by operations to the selection keys. */
  private void processPendingInterests() {
    for (SelectionKey key : selector.keys()) {
      RegisteredSocket socket = (RegisteredSocket) key.attachment();
      int pending = socket.pendingOps.getAndSet(0);
      if (pending != 0) {
        try {
          key.interestOps(key.interestOps() | pending);
        } catch (CancelledKeyException e) {
          // Can happen when channels are closed with pending operations: cancelled keys remain in
          // the key set until the next selection. Must not bring down the selector loop.
        }
      }
    }
  }

  /** Dispatches the pending write operation of the socket, if any, to the executor. */
  private void processWrite(RegisteredSocket socket) {
    socket.writeLock.lock();
    try {
      WriteOperation op = socket.writeOperation;
      if (op != null) {
        executor.execute(
            () -> {
              try {
                doWrite(socket, op);
              } catch (Throwable e) {
                logger.error("error in operation", e);
              }
            });
      }
    } finally {
      socket.writeLock.unlock();
    }
  }

  /** Dispatches the pending read operation of the socket, if any, to the executor. */
  private void processRead(RegisteredSocket socket) {
    socket.readLock.lock();
    try {
      ReadOperation op = socket.readOperation;
      if (op != null) {
        executor.execute(
            () -> {
              try {
                doRead(socket, op);
              } catch (Throwable e) {
                logger.error("error in operation", e);
              }
            });
      }
    } finally {
      socket.readLock.unlock();
    }
  }

  /**
   * Attempts to progress a write operation. It completes on success or I/O failure; on a "needs
   * read/write" condition, the corresponding interest is queued for the selector.
   */
  private void doWrite(RegisteredSocket socket, WriteOperation op) {
    socket.writeLock.lock();
    try {
      if (socket.writeOperation != op) {
        // already completed, cancelled or timed out
        return;
      }
      try {
        long before = op.bufferSet.remaining();
        try {
          writeHandlingTasks(socket, op);
        } finally {
          long c = before - op.bufferSet.remaining();
          Util.assertTrue(c >= 0);
          op.consumesBytes += c;
        }
        socket.writeOperation = null;
        cancelTimeout(op);
        // update counters before running the handler, so that a throwing handler cannot skew them
        successfulWrites.increment();
        currentWrites.decrement();
        op.onSuccess.accept(op.consumesBytes);
      } catch (NeedsReadException e) {
        socket.pendingOps.accumulateAndGet(SelectionKey.OP_READ, (a, b) -> a | b);
        selector.wakeup();
      } catch (NeedsWriteException e) {
        socket.pendingOps.accumulateAndGet(SelectionKey.OP_WRITE, (a, b) -> a | b);
        selector.wakeup();
      } catch (IOException e) {
        if (socket.writeOperation == op) {
          socket.writeOperation = null;
        }
        cancelTimeout(op);
        failedWrites.increment();
        currentWrites.decrement();
        op.onFailure.accept(e);
      }
    } finally {
      socket.writeLock.unlock();
    }
  }

  /**
   * Intended use of the channel group is with sockets that run tasks internally, but out of
   * tolerance, run tasks in thread in case the socket does not.
   */
  private void writeHandlingTasks(RegisteredSocket socket, WriteOperation op) throws IOException {
    while (true) {
      try {
        socket.tlsChannel.write(op.bufferSet.array, op.bufferSet.offset, op.bufferSet.length);
        return;
      } catch (NeedsTaskException e) {
        warnAboutNeedTask();
        e.getTask().run();
      }
    }
  }

  private void warnAboutNeedTask() {
    if (!loggedTaskWarning.getAndSet(true)) {
      logger.warn(
          "caught {}; channels used in asynchronous groups should run tasks themselves; "
              + "although task is being dealt with anyway, consider configuring channels properly",
          NeedsTaskException.class.getName());
    }
  }

  /**
   * Attempts to progress a read operation. It completes on success or I/O failure; on a "needs
   * read/write" condition, the corresponding interest is queued for the selector.
   */
  private void doRead(RegisteredSocket socket, ReadOperation op) {
    socket.readLock.lock();
    try {
      if (socket.readOperation != op) {
        // already completed, cancelled or timed out
        return;
      }
      try {
        Util.assertTrue(op.bufferSet.hasRemaining());
        long c = readHandlingTasks(socket, op);
        Util.assertTrue(c > 0 || c == -1);
        socket.readOperation = null;
        cancelTimeout(op);
        // update counters before running the handler, so that a throwing handler cannot skew them
        successfulReads.increment();
        currentReads.decrement();
        op.onSuccess.accept(c);
      } catch (NeedsReadException e) {
        socket.pendingOps.accumulateAndGet(SelectionKey.OP_READ, (a, b) -> a | b);
        selector.wakeup();
      } catch (NeedsWriteException e) {
        socket.pendingOps.accumulateAndGet(SelectionKey.OP_WRITE, (a, b) -> a | b);
        selector.wakeup();
      } catch (IOException e) {
        if (socket.readOperation == op) {
          socket.readOperation = null;
        }
        cancelTimeout(op);
        failedReads.increment();
        currentReads.decrement();
        op.onFailure.accept(e);
      }
    } finally {
      socket.readLock.unlock();
    }
  }

  /** @see #writeHandlingTasks */
  private long readHandlingTasks(RegisteredSocket socket, ReadOperation op) throws IOException {
    while (true) {
      try {
        return socket.tlsChannel.read(op.bufferSet.array, op.bufferSet.offset, op.bufferSet.length);
      } catch (NeedsTaskException e) {
        warnAboutNeedTask();
        e.getTask().run();
      }
    }
  }

  /** Cancels the timeout task of an operation, if it has one. */
  private static void cancelTimeout(Operation op) {
    if (op.timeoutFuture != null) {
      op.timeoutFuture.cancel(false);
    }
  }

  /**
   * Registers queued sockets in the selector. A channel that was closed before it could be
   * registered is left unregistered (key stays {@code null}); in every case the socket's latch is
   * released so that no thread waits forever.
   */
  private void registerPendingSockets() {
    RegisteredSocket socket;
    while ((socket = pendingRegistrations.poll()) != null) {
      try {
        socket.key = socket.socketChannel.register(selector, 0, socket);
        logger.trace("registered key: {}", socket.key);
      } catch (ClosedChannelException e) {
        // must not bring down the selector loop of the whole group
        logger.debug("channel closed before registration: {}", socket.socketChannel);
      } finally {
        socket.registered.countDown();
      }
    }
  }

  /**
   * Whether either {@link #shutdown()} or {@link #shutdownNow()} have been called, or the selector
   * loop has stopped because of an error.
   *
   * @return {@code true} if this group has initiated shutdown and {@code false} if the group is
   *     active
   */
  public boolean isShutdown() {
    return shutdown != Shutdown.No;
  }

  /**
   * Starts the shutdown process. New sockets cannot be registered, already registered ones
   * continue operating normally until they are closed.
   */
  public void shutdown() {
    shutdown = Shutdown.Wait;
    selector.wakeup();
  }

  /**
   * Shuts down this channel group immediately. All registered sockets are closed, pending
   * operations may or may not finish.
   */
  public void shutdownNow() {
    shutdown = Shutdown.Immediate;
    selector.wakeup();
  }

  /**
   * Whether this channel group was shut down, and all pending tasks have drained.
   *
   * @return whether the channel is terminated
   */
  public boolean isTerminated() {
    return executor.isTerminated();
  }

  /**
   * Blocks until all registered sockets are closed and pending tasks finished execution after a
   * shutdown request, or the timeout occurs, or the current thread is interrupted, whichever
   * happens first.
   *
   * @param timeout the maximum time to wait
   * @param unit the time unit of the timeout argument
   * @return {@code true} if this group terminated and {@code false} if the timeout elapsed before
   *     termination
   * @throws InterruptedException if interrupted while waiting
   */
  public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
    return executor.awaitTermination(timeout, unit);
  }

  long getSelectionCount() {
    return selectionCount.longValue();
  }

  /**
   * Return the total number of read operations that were started.
   *
   * @return number of operations
   */
  public long getStartedReadCount() {
    return startedReads.longValue();
  }

  /**
   * Return the total number of write operations that were started.
   *
   * @return number of operations
   */
  public long getStartedWriteCount() {
    return startedWrites.longValue();
  }

  /**
   * Return the total number of read operations that succeeded.
   *
   * @return number of operations
   */
  public long getSuccessfulReadCount() {
    return successfulReads.longValue();
  }

  /**
   * Return the total number of write operations that succeeded.
   *
   * @return number of operations
   */
  public long getSuccessfulWriteCount() {
    return successfulWrites.longValue();
  }

  /**
   * Return the total number of read operations that failed.
   *
   * @return number of operations
   */
  public long getFailedReadCount() {
    return failedReads.longValue();
  }

  /**
   * Return the total number of write operations that failed.
   *
   * @return number of operations
   */
  public long getFailedWriteCount() {
    return failedWrites.longValue();
  }

  /**
   * Return the total number of read operations that were cancelled.
   *
   * @return number of operations
   */
  public long getCancelledReadCount() {
    return cancelledReads.longValue();
  }

  /**
   * Return the total number of write operations that were cancelled.
   *
   * @return number of operations
   */
  public long getCancelledWriteCount() {
    return cancelledWrites.longValue();
  }

  /**
   * Returns the current number of active read operations.
   *
   * @return number of operations
   */
  public long getCurrentReadCount() {
    return currentReads.longValue();
  }

  /**
   * Returns the current number of active write operations.
   *
   * @return number of operations
   */
  public long getCurrentWriteCount() {
    return currentWrites.longValue();
  }

  /**
   * Returns the current number of registered sockets.
   *
   * @return number of sockets
   */
  public long getCurrentRegistrationCount() {
    return currentRegistrations.longValue();
  }
}