package tlschannel.impl;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Optional;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tlschannel.*;
import tlschannel.TlsChannelCallbackException;
import tlschannel.util.Util;

/**
 * A {@link ByteChannel} that encrypts written data and decrypts read data through an {@link
 * SSLEngine}, using the supplied underlying channels for the encrypted bytes.
 */
public class TlsChannelImpl implements ByteChannel {

  private static final Logger logger = LoggerFactory.getLogger(TlsChannelImpl.class);

  /** Initial capacity, in bytes, of the internally allocated buffers. */
  public static final int buffersInitialSize = 4096;

  /** Official TLS max data size is 2^14 = 16k. Use 1024 more to account for the overhead */
  public static final int maxTlsPacketSize = 17 * 1024;

  /** Used to signal EOF conditions from the underlying channel */
  public static class EofException extends Exception {
    private static final long serialVersionUID = -3859156713994602991L;

    /** For efficiency, override this method to do nothing. */
    @Override
    public Throwable fillInStackTrace() {
      return this;
    }
  }

  private final ReadableByteChannel readChannel;
  private final WritableByteChannel writeChannel;
  private final SSLEngine engine;
  private BufferHolder inEncrypted;
  private final Consumer<SSLSession> initSessionCallback;

  private final boolean runTasks;
  private final TrackingAllocator encryptedBufAllocator;
  private final TrackingAllocator plainBufAllocator;
  private final boolean waitForCloseConfirmation;

  /**
   * Creates the channel implementation.
   *
   * @param readChannel channel the encrypted bytes are read from
   * @param writeChannel channel the encrypted bytes are written to
   * @param engine engine used for the handshake, wrapping and unwrapping
   * @param inEncrypted optional pre-existing buffer for incoming encrypted bytes; when empty, a new
   *     one is created using {@code encryptedBufAllocator}
   * @param initSessionCallback invoked with the session after every (re)negotiation
   * @param runTasks whether delegated engine tasks are run inline ({@code true}) or handed to the
   *     caller through {@link NeedsTaskException} ({@code false})
   * @param plainBufAllocator allocator for the plaintext buffer
   * @param encryptedBufAllocator allocator for the encrypted buffers
   * @param releaseBuffers passed to every {@link BufferHolder} created here
   * @param waitForCloseConfirmation whether {@link #close()} also tries to read the peer's
   *     close_notify after sending its own
   */
  // @formatter:off
  public TlsChannelImpl(
      ReadableByteChannel readChannel,
      WritableByteChannel writeChannel,
      SSLEngine engine,
      Optional<BufferHolder> inEncrypted,
      Consumer<SSLSession> initSessionCallback,
      boolean runTasks,
      TrackingAllocator plainBufAllocator,
      TrackingAllocator encryptedBufAllocator,
      boolean releaseBuffers,
      boolean waitForCloseConfirmation) {
    // @formatter:on
    this.readChannel = readChannel;
    this.writeChannel = writeChannel;
    this.engine = engine;
    this.inEncrypted =
        inEncrypted.orElseGet(
            () ->
                new BufferHolder(
                    "inEncrypted",
                    Optional.empty(),
                    encryptedBufAllocator,
                    buffersInitialSize,
                    maxTlsPacketSize,
                    false /* plainData */,
                    releaseBuffers));
    this.initSessionCallback = initSessionCallback;
    this.runTasks = runTasks;
    this.plainBufAllocator = plainBufAllocator;
    this.encryptedBufAllocator = encryptedBufAllocator;
    this.waitForCloseConfirmation = waitForCloseConfirmation;
    inPlain =
        new BufferHolder(
            "inPlain",
            Optional.empty(),
            plainBufAllocator,
            buffersInitialSize,
            maxTlsPacketSize,
            true /* plainData */,
            releaseBuffers);
    outEncrypted =
        new BufferHolder(
            "outEncrypted",
            Optional.empty(),
            encryptedBufAllocator,
            buffersInitialSize,
            maxTlsPacketSize,
            false /* plainData */,
            releaseBuffers);
  }

  private final Lock initLock = new ReentrantLock();
  private final Lock readLock = new ReentrantLock();
  private final Lock writeLock = new ReentrantLock();

  private volatile boolean negotiated = false;

  /**
   * Whether an IOException was received from the underlying channel or from the {@link SSLEngine}.
   */
  private volatile boolean invalid = false;

  /** Whether a close_notify was already sent (set as soon as sending starts). */
  private volatile boolean shutdownSent = false;

  /** Whether a close_notify was already received. */
  private volatile boolean shutdownReceived = false;

  /**
   * Whether {@link SSLEngine#closeOutbound()} was called and the resulting close_notify was wrapped
   * into {@link #outEncrypted}. Only accessed while holding both {@link #readLock} and {@link
   * #writeLock}.
   */
  private boolean closeNotifyWrapped = false;

  /**
   * Whether the wrapped close_notify, together with any encrypted bytes queued before it, was
   * completely written to the underlying channel. Only accessed while holding both {@link
   * #readLock} and {@link #writeLock}.
   */
  private boolean closeNotifyFlushed = false;

  /** Decrypted data from inEncrypted */
  private BufferHolder inPlain;

  /** Contains data encrypted to send to the underlying channel */
  private BufferHolder outEncrypted;

  /**
   * Reference to the current read buffer supplied by the client; this field is only valid during a
   * read operation. This field is used instead of {@link #inPlain} in order to avoid copying
   * returned bytes when possible.
   */
  private ByteBufferSet suppliedInPlain;

  /** Bytes produced by the current read operation */
  private int bytesToReturn;

  /**
   * Handshake wrap() method calls need a buffer to read from, even when they actually do not read
   * anything.
   *
   * <p>Note: standard SSLEngine is happy with no buffers, the empty buffer is here to make this
   * work with Netty's OpenSSL's wrapper.
   */
  private final ByteBufferSet dummyOut =
      new ByteBufferSet(new ByteBuffer[] {ByteBuffer.allocate(0)});

  /** Returns the callback invoked with the session after each (re)negotiation. */
  public Consumer<SSLSession> getSessionInitCallback() {
    return initSessionCallback;
  }

  /** Returns the allocator used for the plaintext buffer. */
  public TrackingAllocator getPlainBufferAllocator() {
    return plainBufAllocator;
  }

  /** Returns the allocator used for the encrypted buffers. */
  public TrackingAllocator getEncryptedBufferAllocator() {
    return encryptedBufAllocator;
  }

  // read

  /**
   * Reads decrypted application data into {@code dest}, doing the initial handshake first if it
   * has not been done yet.
   *
   * @param dest destination buffers; must not be read-only
   * @return number of bytes produced; 0 if {@code dest} has no remaining space; -1 on EOF of the
   *     underlying channel, after a close_notify was received, or on an unsupported handshake
   *     status
   * @throws ClosedChannelException if the channel is invalid or a close_notify was already sent
   * @throws NeedsTaskException if a delegated task must be run and this channel does not run tasks
   *     itself
   * @throws IOException if the underlying channel or the engine fail
   */
  public long read(ByteBufferSet dest) throws IOException, NeedsTaskException {
    checkReadBuffer(dest);
    if (!dest.hasRemaining()) return 0;
    handshake();
    readLock.lock();
    try {
      if (invalid || shutdownSent) {
        throw new ClosedChannelException();
      }

      long originalDestPosition = dest.position();
      suppliedInPlain = dest;
      bytesToReturn = inPlain.nullOrEmpty() ? 0 : inPlain.buffer.position();

      while (true) {

        // return bytes as soon as we have them
        if (bytesToReturn > 0) {
          if (inPlain.nullOrEmpty()) {
            // if there is no internal buffer, the bytes must have been placed directly in the
            // supplied buffer
            Util.assertTrue(dest.position() == originalDestPosition + bytesToReturn);
            return bytesToReturn;
          } else {
            Util.assertTrue(inPlain.buffer.position() == bytesToReturn);
            return transferPendingPlain(dest);
          }
        }

        if (shutdownReceived) {
          return -1;
        }
        Util.assertTrue(inPlain.nullOrEmpty());
        switch (engine.getHandshakeStatus()) {
          case NEED_UNWRAP:
          case NEED_WRAP:
            writeAndHandshake();
            break;
          case NOT_HANDSHAKING:
          case FINISHED:
            readAndUnwrap();
            if (shutdownReceived) {
              return -1;
            }
            break;
          case NEED_TASK:
            handleTask();
            break;
          default:
            // Unsupported stage eg: NEED_UNWRAP_AGAIN
            return -1;
        }
      }
    } catch (EofException e) {
      return -1;
    } finally {
      bytesToReturn = 0;
      suppliedInPlain = null;
      readLock.unlock();
    }
  }

  /**
   * Runs the engine's delegated task inline when {@link #runTasks} is set; otherwise hands it to
   * the caller.
   *
   * @throws NeedsTaskException carrying the task, when tasks are not run inline
   */
  private void handleTask() throws NeedsTaskException {
    if (runTasks) {
      engine.getDelegatedTask().run();
    } else {
      throw new NeedsTaskException(engine.getDelegatedTask());
    }
  }

  /**
   * Copies bytes from the internal input plain buffer to the supplied buffer.
   *
   * @return number of bytes copied
   */
  private int transferPendingPlain(ByteBufferSet dstBuffers) {
    inPlain.buffer.flip(); // will read
    int bytes = dstBuffers.putRemaining(inPlain.buffer);
    inPlain.buffer.compact(); // will write
    boolean disposed = inPlain.release();
    if (!disposed) {
      inPlain.zeroRemaining();
    }
    return bytes;
  }

  /**
   * Unwraps buffered encrypted bytes, preferably directly into the client-supplied buffer. It falls
   * back to (and grows) {@link #inPlain} on overflow. Stops as soon as data is produced, more input
   * is needed, the engine is closed, or the handshake status differs from {@code originalStatus}.
   * The number of bytes produced is stored in {@link #bytesToReturn}.
   */
  private void unwrapLoop(HandshakeStatus originalStatus) throws SSLException {

    ByteBufferSet effDest;
    if (suppliedInPlain != null) {
      effDest = suppliedInPlain;
    } else {
      inPlain.prepare();
      effDest = new ByteBufferSet(inPlain.buffer);
    }

    while (true) {
      Util.assertTrue(inPlain.nullOrEmpty());
      SSLEngineResult result = callEngineUnwrap(effDest);
      /*
       * Note that data can be returned even in case of overflow, in that
       * case, just return the data.
       */
      if (result.bytesProduced() > 0
          || result.getStatus() == Status.BUFFER_UNDERFLOW
          || result.getStatus() == Status.CLOSED
          || result.getHandshakeStatus() != originalStatus) {
        bytesToReturn = result.bytesProduced();
        if (result.getStatus() == Status.CLOSED) {
          shutdownReceived = true;
        }
        return;
      }
      if (result.getStatus() == Status.BUFFER_OVERFLOW) {
        if (suppliedInPlain != null && effDest == suppliedInPlain) {
          /*
           * The client-supplied buffer is not big enough. Use the
           * internal inPlain buffer, also ensure that it is bigger
           * than the too-small supplied one.
           */
          inPlain.prepare();
          ensureInPlainCapacity(
              Math.min(((int) suppliedInPlain.remaining()) * 2, maxTlsPacketSize));
        } else {
          inPlain.enlarge();
        }
        // inPlain changed, re-create the wrapper
        effDest = new ByteBufferSet(inPlain.buffer);
      }
    }
  }

  /**
   * Calls {@link SSLEngine#unwrap} reading from {@link #inEncrypted} into {@code dest}. Any
   * {@link SSLException} marks the channel as invalid before being rethrown.
   */
  private SSLEngineResult callEngineUnwrap(ByteBufferSet dest) throws SSLException {
    inEncrypted.buffer.flip();
    try {
      SSLEngineResult result =
          engine.unwrap(inEncrypted.buffer, dest.array, dest.offset, dest.length);
      if (logger.isTraceEnabled()) {
        logger.trace(
            "engine.unwrap() result [{}]. Engine status: {}; inEncrypted {}; inPlain: {}",
            Util.resultToString(result),
            result.getHandshakeStatus(),
            inEncrypted,
            dest);
      }
      return result;
    } catch (SSLException e) {
      // something bad was received from the underlying channel, we cannot
      // continue
      invalid = true;
      throw e;
    } finally {
      inEncrypted.buffer.compact();
    }
  }

  /**
   * Reads from {@link #readChannel} into {@link #inEncrypted}. Any IOException other than a {@link
   * WouldBlockException} marks the channel as invalid.
   */
  private int readFromChannel() throws IOException, EofException {
    try {
      return readFromChannel(readChannel, inEncrypted.buffer);
    } catch (WouldBlockException e) {
      throw e;
    } catch (IOException e) {
      invalid = true;
      throw e;
    }
  }

  /**
   * Performs a single read from {@code readChannel} into {@code buffer}, which must have remaining
   * space.
   *
   * @return the (positive) number of bytes read
   * @throws EofException if the channel reached end-of-stream
   * @throws NeedsReadException if the channel returned 0 bytes (non-blocking channel not ready)
   */
  public static int readFromChannel(ReadableByteChannel readChannel, ByteBuffer buffer)
      throws IOException, EofException {
    Util.assertTrue(buffer.hasRemaining());
    logger.trace("Reading from channel");
    int c = readChannel.read(buffer); // IO block
    logger.trace("Read from channel; response: {}, buffer: {}", c, buffer);
    if (c == -1) {
      throw new EofException();
    }
    if (c == 0) {
      throw new NeedsReadException();
    }
    return c;
  }

  // write

  /**
   * Encrypts and writes all remaining bytes of {@code source}, doing the initial handshake first if
   * it has not been done yet.
   *
   * @return the number of plaintext bytes consumed (all remaining bytes of {@code source})
   * @throws ClosedChannelException if the channel is invalid or a close_notify was already sent
   * @throws IOException if the underlying channel or the engine fail
   */
  public long write(ByteBufferSet source) throws IOException {
    /*
     * Note that we should enter the write loop even in the case that the source buffer has no remaining bytes,
     * as it could be the case, in non-blocking usage, that the user is forced to call write again after the
     * underlying channel is available for writing, just to write pending encrypted bytes.
     */
    handshake();
    writeLock.lock();
    try {
      if (invalid || shutdownSent) {
        throw new ClosedChannelException();
      }
      return wrapAndWrite(source);
    } finally {
      writeLock.unlock();
    }
  }

  /**
   * Alternates between flushing {@link #outEncrypted} and wrapping more of {@code source} until
   * everything is wrapped and written.
   */
  private long wrapAndWrite(ByteBufferSet source) throws IOException {
    long bytesToConsume = source.remaining();
    outEncrypted.prepare();
    try {
      while (true) {
        writeToChannel();
        if (source.remaining() == 0) {
          return bytesToConsume;
        }
        wrapLoop(source);
      }
    } finally {
      outEncrypted.release();
    }
  }

  /**
   * Wraps {@code source} into {@link #outEncrypted}, enlarging the buffer on overflow, until the
   * engine reports {@code OK} or {@code CLOSED}.
   *
   * @throws IllegalStateException if the engine reports {@code BUFFER_UNDERFLOW} on wrap
   */
  private void wrapLoop(ByteBufferSet source) throws SSLException {
    while (true) {
      SSLEngineResult result = callEngineWrap(source);
      switch (result.getStatus()) {
        case OK:
        case CLOSED:
          return;
        case BUFFER_OVERFLOW:
          Util.assertTrue(result.bytesConsumed() == 0);
          outEncrypted.enlarge();
          break;
        case BUFFER_UNDERFLOW:
          throw new IllegalStateException();
      }
    }
  }

  /**
   * Calls {@link SSLEngine#wrap} reading from {@code source} into {@link #outEncrypted}. Any {@link
   * SSLException} marks the channel as invalid before being rethrown.
   */
  private SSLEngineResult callEngineWrap(ByteBufferSet source) throws SSLException {
    try {
      SSLEngineResult result =
          engine.wrap(source.array, source.offset, source.length, outEncrypted.buffer);
      if (logger.isTraceEnabled()) {
        logger.trace(
            "engine.wrap() result: [{}]; engine status: {}; srcBuffer: {}, outEncrypted: {}",
            Util.resultToString(result),
            result.getHandshakeStatus(),
            source,
            outEncrypted);
      }
      return result;
    } catch (SSLException e) {
      invalid = true;
      throw e;
    }
  }

  /** Grows {@link #inPlain} to at least {@code newCapacity} bytes; never shrinks it. */
  private void ensureInPlainCapacity(int newCapacity) {
    if (inPlain.buffer.capacity() < newCapacity) {
      logger.trace(
          "inPlain buffer too small, increasing from {} to {}",
          inPlain.buffer.capacity(),
          newCapacity);
      inPlain.resize(newCapacity);
    }
  }

  /**
   * Writes all pending bytes of {@link #outEncrypted} to {@link #writeChannel}; a no-op when
   * nothing is pending. Bytes not written remain in the buffer. Any IOException other than a
   * {@link WouldBlockException} marks the channel as invalid.
   */
  private void writeToChannel() throws IOException {
    if (outEncrypted.buffer.position() == 0) {
      return;
    }
    outEncrypted.buffer.flip();
    try {
      try {
        writeToChannel(writeChannel, outEncrypted.buffer);
      } catch (WouldBlockException e) {
        throw e;
      } catch (IOException e) {
        invalid = true;
        throw e;
      }
    } finally {
      outEncrypted.buffer.compact();
    }
  }

  /**
   * Writes {@code src} to {@code channel} until it has no remaining bytes.
   *
   * @throws NeedsWriteException if the channel accepted 0 bytes (non-blocking channel not ready)
   */
  private static void writeToChannel(WritableByteChannel channel, ByteBuffer src)
      throws IOException {
    while (src.hasRemaining()) {
      logger.trace("Writing to channel: {}", src);
      int c = channel.write(src);
      if (c == 0) {
        /*
         * If no bytes were written, it means that the socket is
         * non-blocking and needs more buffer space, so stop the loop
         */
        throw new NeedsWriteException();
      }
      // A blocking channel can write fewer bytes than requested just before failing; looping
      // again makes the next write() call surface that error.
    }
  }

  // handshake and close

  /**
   * Force a new negotiation.
   *
   * @throws SSLException if the session protocol is TLS 1.3 or later, which has no renegotiation
   * @throws ClosedChannelException if the channel is closed or EOF is reached during the handshake
   * @throws IOException if the underlying channel throws an IOException
   */
  public void renegotiate() throws IOException {
    /*
     * Renegotiation was removed in TLS 1.3. We have to do the check at this level because SSLEngine will not
     * check that, and just enter into undefined behavior.
     */
    // relying on hopefully-robust lexicographic ordering of protocol names
    if (engine.getSession().getProtocol().compareTo("TLSv1.3") >= 0) {
      throw new SSLException("renegotiation not supported in TLS 1.3 or latter");
    }
    try {
      doHandshake(true /* force */);
    } catch (EofException e) {
      throw new ClosedChannelException();
    }
  }

  /**
   * Do a negotiation if this connection is new and it hasn't been done already.
   *
   * @throws ClosedChannelException if the channel is closed or EOF is reached during the handshake
   * @throws IOException if the underlying channel throws an IOException
   */
  public void handshake() throws IOException {
    try {
      doHandshake(false /* force */);
    } catch (EofException e) {
      throw new ClosedChannelException();
    }
  }

  /**
   * Runs a handshake under {@link #initLock} when forced or not yet negotiated. Afterwards it
   * rejects DTLS sessions and invokes the session callback.
   *
   * @throws TlsChannelCallbackException if the session callback throws
   */
  private void doHandshake(boolean force) throws IOException, EofException {
    if (!force && negotiated) return;
    initLock.lock();
    try {
      if (invalid || shutdownSent) throw new ClosedChannelException();
      if (force || !negotiated) {
        engine.beginHandshake();
        logger.trace("Called engine.beginHandshake()");
        writeAndHandshake();

        if (engine.getSession().getProtocol().startsWith("DTLS")) {
          throw new IllegalArgumentException("DTLS not supported");
        }

        // call client code
        try {
          initSessionCallback.accept(engine.getSession());
        } catch (Exception e) {
          logger.trace("client code threw exception in session initialization callback", e);
          throw new TlsChannelCallbackException("session initialization callback failed", e);
        }
        negotiated = true;
      }
    } finally {
      initLock.unlock();
    }
  }

  /**
   * Takes both I/O locks, flushes pending encrypted output and drives the handshake loop.
   *
   * @return bytes of application data produced while handshaking (see {@link #handshakeLoop()})
   */
  private int writeAndHandshake() throws IOException, EofException {
    readLock.lock();
    try {
      writeLock.lock();
      try {
        Util.assertTrue(inPlain.nullOrEmpty());
        outEncrypted.prepare();
        try {
          writeToChannel(); // IO block
          return handshakeLoop();
        } finally {
          outEncrypted.release();
        }
      } finally {
        writeLock.unlock();
      }
    } finally {
      readLock.unlock();
    }
  }

  /**
   * Steps the engine according to its handshake status until the handshake ends or application data
   * is produced.
   *
   * @return the number of application bytes produced by an unwrap, or 0 when the handshake ended
   *     (or reached an unsupported status)
   */
  private int handshakeLoop() throws IOException, EofException {
    Util.assertTrue(inPlain.nullOrEmpty());
    while (true) {
      switch (engine.getHandshakeStatus()) {
        case NEED_WRAP:
          Util.assertTrue(outEncrypted.nullOrEmpty());
          wrapLoop(dummyOut);
          writeToChannel(); // IO block
          break;
        case NEED_UNWRAP:
          readAndUnwrap();
          if (bytesToReturn > 0) {
            return bytesToReturn;
          }
          break;
        case NOT_HANDSHAKING:
          /*
           * This should not really happen using SSLEngine, because
           * handshaking ends with a FINISHED status. However, we accept
           * this value to permit the use of a pass-through stub engine
           * with no encryption.
           */
          return 0;
        case NEED_TASK:
          handleTask();
          break;
        case FINISHED:
          return 0;
        default:
          // Unsupported stage eg: NEED_UNWRAP_AGAIN
          return 0;
      }
    }
  }

  /**
   * Repeatedly unwraps buffered input and reads more from the underlying channel. It stops when data
   * is produced, the handshake status changes, or a close_notify is received.
   */
  private void readAndUnwrap() throws IOException, EofException {
    // Save status before operation: use it to stop when status changes
    HandshakeStatus orig = engine.getHandshakeStatus();
    inEncrypted.prepare();
    try {
      while (true) {
        Util.assertTrue(inPlain.nullOrEmpty());
        unwrapLoop(orig);
        if (bytesToReturn > 0 || engine.getHandshakeStatus() != orig || shutdownReceived) {
          return;
        }
        if (!inEncrypted.buffer.hasRemaining()) {
          inEncrypted.enlarge();
        }
        readFromChannel(); // IO block
      }
    } finally {
      inEncrypted.release();
    }
  }

  /**
   * Attempts a best-effort TLS shutdown, closes both underlying channels and frees the internal
   * buffers. The buffers are freed, and both channels are closed, even if closing one of the
   * channels fails.
   *
   * @throws IOException if closing an underlying channel fails; a second failure, if any, is
   *     attached as suppressed
   */
  public void close() throws IOException {
    tryShutdown();
    try {
      closeUnderlyingChannels();
    } finally {
      /*
       * After closing the underlying channels, locks should be taken fast.
       */
      readLock.lock();
      try {
        writeLock.lock();
        try {
          freeBuffers();
        } finally {
          writeLock.unlock();
        }
      } finally {
        readLock.unlock();
      }
    }
  }

  /** Closes the write channel and then the read channel, attempting the second even if the first fails. */
  private void closeUnderlyingChannels() throws IOException {
    try {
      writeChannel.close();
    } catch (IOException writeFailure) {
      try {
        readChannel.close();
      } catch (IOException readFailure) {
        if (readFailure != writeFailure) {
          writeFailure.addSuppressed(readFailure);
        }
      }
      throw writeFailure;
    }
    readChannel.close();
  }

  /**
   * Best-effort shutdown used by {@link #close()}. It is skipped if another thread currently holds
   * either I/O lock or a close_notify was already sent. Any error is logged and ignored.
   */
  private void tryShutdown() {
    if (!readLock.tryLock()) return;
    try {
      if (!writeLock.tryLock()) return;
      try {
        if (!shutdownSent) {
          try {
            boolean closed = shutdown();
            if (!closed && waitForCloseConfirmation) {
              shutdown();
            }
          } catch (Throwable e) {
            logger.debug("error doing TLS shutdown on close(), continuing: {}", e.getMessage());
          }
        }
      } finally {
        writeLock.unlock();
      }
    } finally {
      readLock.unlock();
    }
  }

  /**
   * Performs (one step of) the TLS close_notify exchange.
   *
   * <p>The first call sends close_notify and returns whether the peer's close_notify had already
   * been received. Subsequent calls read the peer's close_notify if necessary and return {@code
   * true}. If sending was interrupted by a would-block condition (non-blocking channels), the next
   * call resumes sending before doing anything else, so the close_notify is never lost.
   *
   * @return whether the shutdown is complete in both directions
   * @throws ClosedChannelException if the channel is invalid, its buffers were already freed
   *     without completing the shutdown, or EOF is reached before the peer's close_notify
   * @throws IOException if the underlying channel or the engine fail
   */
  public boolean shutdown() throws IOException {
    readLock.lock();
    try {
      writeLock.lock();
      try {
        if (invalid) {
          throw new ClosedChannelException();
        }
        if (inEncrypted == null) {
          // Buffers were already freed: either the shutdown completed, or close() ran without it.
          if (shutdownSent && shutdownReceived) {
            return true;
          }
          throw new ClosedChannelException();
        }
        boolean firstCall = !shutdownSent;
        shutdownSent = true;
        if (!closeNotifyFlushed) {
          sendCloseNotify();
        }
        if (firstCall) {
          /*
           * If this side is the first to send close_notify, then,
           * inbound is not done and false should be returned (so the
           * client waits for the response). If this side is the
           * second, then inbound was already done, and we can return
           * true.
           */
          if (shutdownReceived) {
            freeBuffers();
          }
          return shutdownReceived;
        }
        /*
         * If we reach this point, then we just have to read the close
         * notification from the client. Only try to do it if necessary,
         * to make this method idempotent.
         */
        if (!shutdownReceived) {
          try {
            // IO block
            readAndUnwrap();
            Util.assertTrue(shutdownReceived);
          } catch (EofException e) {
            throw new ClosedChannelException();
          }
        }
        freeBuffers();
        return true;
      } finally {
        writeLock.unlock();
      }
    } finally {
      readLock.unlock();
    }
  }

  /**
   * Flushes pending encrypted bytes, wraps a close_notify (only once) and flushes it as well.
   *
   * <p>Safe to call again after a would-block exception: encrypted bytes left unwritten stay queued
   * in {@link #outEncrypted} and are flushed first, and {@link SSLEngine#closeOutbound()} is not
   * repeated. NOTE(review): relies on {@code outEncrypted.release()} keeping non-empty buffers, as
   * {@link #write(ByteBufferSet)} already does for pending bytes. Must be called with both I/O
   * locks held.
   */
  private void sendCloseNotify() throws IOException {
    outEncrypted.prepare();
    try {
      writeToChannel(); // IO block: pending data (or a partially written close_notify)
      if (!closeNotifyWrapped) {
        engine.closeOutbound();
        wrapLoop(dummyOut);
        closeNotifyWrapped = true;
        writeToChannel(); // IO block
      }
      closeNotifyFlushed = true;
    } finally {
      outEncrypted.release();
    }
  }

  /** Disposes all internal buffers and drops their references; safe to call more than once. */
  private void freeBuffers() {
    if (inEncrypted != null) {
      inEncrypted.dispose();
      inEncrypted = null;
    }
    if (inPlain != null) {
      inPlain.dispose();
      inPlain = null;
    }
    if (outEncrypted != null) {
      outEncrypted.dispose();
      outEncrypted = null;
    }
  }

  /** Returns {@code true} if the channel is not invalid and both underlying channels are open. */
  public boolean isOpen() {
    return !invalid && writeChannel.isOpen() && readChannel.isOpen();
  }

  /**
   * Validates a destination buffer for reading.
   *
   * @throws IllegalArgumentException if {@code dest} is read-only
   */
  public static void checkReadBuffer(ByteBufferSet dest) {
    if (dest.isReadOnly()) throw new IllegalArgumentException();
  }

  /** Returns the underlying {@link SSLEngine}. */
  public SSLEngine engine() {
    return engine;
  }

  /** Returns whether delegated engine tasks are run inline. */
  public boolean getRunTasks() {
    return runTasks;
  }

  @Override
  public int read(ByteBuffer dst) throws IOException {
    return (int) read(new ByteBufferSet(dst));
  }

  @Override
  public int write(ByteBuffer src) throws IOException {
    return (int) write(new ByteBufferSet(src));
  }

  /** Returns whether a close_notify was received from the peer. */
  public boolean shutdownReceived() {
    return shutdownReceived;
  }

  /** Returns whether sending a close_notify was started. */
  public boolean shutdownSent() {
    return shutdownSent;
  }

  /** Returns the underlying channel encrypted bytes are read from. */
  public ReadableByteChannel plainReadableChannel() {
    return readChannel;
  }

  /** Returns the underlying channel encrypted bytes are written to. */
  public WritableByteChannel plainWritableChannel() {
    return writeChannel;
  }
}