/*
 * Copyright (c) 2015 Twitter, Inc. All rights reserved.
 * Licensed under the Apache License v2.0
 * http://www.apache.org/licenses/LICENSE-2.0
 */

package com.twitter.whiskey.nio;

import com.twitter.whiskey.futures.Inline;
import com.twitter.whiskey.util.Origin;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;

/**
 * An asynchronous TLS socket interface.
 *
 * <p>Layers an {@link SSLEngine} (always switched to client mode) on top of the plain
 * {@link Socket} transport. While the engine reports a handshake in progress, reads and writes
 * are routed through private handshake queues. Once it reports
 * {@link SSLEngineResult.HandshakeStatus#NOT_HANDSHAKING}, the base socket's queues are used.
 * Renegotiation after the initial handshake is rejected with an {@link SSLException}.
 *
 * @author Bill Gallagher
 */
public final class SSLSocket extends Socket {

    private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = new ByteBuffer[0];

    private final SSLEngine engine;

    // Outbound handshake records queued by wrapHandshake(); returned by getWriteQueue() while
    // the engine is handshaking.
    private final Deque<WriteFuture> handshakeWriteQueue = new ArrayDeque<>(32);
    // Read queue returned by getReadQueue() while the engine is handshaking.
    private final Deque<ReadFuture> handshakeReadQueue = new ArrayDeque<>();

    // Network (TLS record) bytes received but not yet consumed by the engine. Between
    // operations it is kept in "write mode": position == number of buffered bytes.
    // Not final because appendWrapped() may swap in a larger buffer.
    private ByteBuffer bufferedWrapped;

    /**
     * Creates a TLS socket that acts as the client side of the handshake.
     *
     * @param origin passed to the {@link Socket} constructor
     * @param runLoop passed to the {@link Socket} constructor
     * @param engine the engine for this session; it is put into client mode here
     */
    public SSLSocket(Origin origin, RunLoop runLoop, SSLEngine engine) {
        super(origin, runLoop);
        this.engine = engine;
        this.engine.setUseClientMode(true);
        bufferedWrapped = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
    }

    @Override
    boolean isSecure() {
        return true;
    }

    /**
     * Starts the TLS handshake. The base class's {@code finishConnect()} runs later, only once
     * the engine reports the handshake as {@code FINISHED}.
     */
    @Override
    void finishConnect() throws IOException {
        // writing an empty buffer will initiate a handshake
        wrapHandshake();
    }

    /**
     * Moves every handshake record the engine currently wants to send into
     * {@link #handshakeWriteQueue}, then advances the handshake. If the engine reported
     * {@code FINISHED}, the connect is completed. If it needs peer data, exactly one unwrap
     * step is started.
     *
     * @throws IOException if the engine fails, or propagated from the unwrap step
     */
    private void wrapHandshake() throws IOException {
        ByteBuffer out = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());

        SSLEngineResult result;
        boolean finished = false;
        boolean again;
        do {
            result = engine.wrap(EMPTY_BUFFER_ARRAY, out);
            again = result.bytesProduced() > 0;

            if (again) {
                out.flip();
                handshakeWriteQueue.add(new WriteFuture(new ByteBuffer[] { out }));
                out = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
            }

            switch (result.getHandshakeStatus()) {
                case FINISHED:
                    finished = true;
                    break;
                case NEED_TASK:
                    runDelegatedTasks(engine);
                    // the engine's next step is only known after asking it again; stopping
                    // here when no bytes were produced would stall the handshake
                    again = true;
                    break;
                default:
                    // NEED_WRAP keeps looping only while records are produced;
                    // NEED_UNWRAP is handled once, after the loop
                    break;
            }
        } while (again);

        if (finished) {
            super.finishConnect();
            // peer bytes that arrived alongside the handshake would otherwise wait for the
            // next readable event
            if (bufferedWrapped.position() > 0) {
                onReadable();
            }
        } else if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
            // Dispatched once, after the loop. Doing it inside the loop could queue a second
            // handshake read whenever the loop ran more than once.
            if (bufferedWrapped.position() > 0) {
                // try already-buffered peer bytes before waiting on the socket
                unwrapHandshake(ByteBuffer.allocate(0));
            } else {
                readAndUnwrapHandshake();
            }
        }
    }

    /**
     * Feeds peer handshake bytes to the engine and advances the handshake. It stops when the
     * engine wants to wrap, when it reports {@code FINISHED}, or when it needs more bytes than
     * are buffered. In the last case another read is issued.
     *
     * @param wrappedBuf newly received network bytes (may be empty); all of its remaining bytes
     *     are appended to {@link #bufferedWrapped}
     * @throws IOException if the engine rejects the data or is closed mid-handshake
     */
    private void unwrapHandshake(ByteBuffer wrappedBuf) throws IOException {
        appendWrapped(wrappedBuf);

        // TODO(bgallagher) buffer pooling
        // Whatever the engine writes here is discarded, so one scratch buffer serves every step.
        ByteBuffer scratch = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());

        SSLEngineResult result;
        do {
            scratch.clear();
            bufferedWrapped.flip();
            try {
                result = engine.unwrap(bufferedWrapped, scratch);
            } finally {
                // back to write mode, keeping any partial record for the next read
                bufferedWrapped.compact();
            }

            if (result.getStatus() == SSLEngineResult.Status.CLOSED) {
                // a closed engine makes no further progress; without this the loop would spin
                throw new SSLException("connection closed during handshake");
            }

            switch (result.getHandshakeStatus()) {
                case NEED_WRAP:
                    wrapHandshake();
                    return;
                case NEED_TASK:
                    runDelegatedTasks(engine);
                    break;
                case FINISHED:
                    super.finishConnect();
                    if (bufferedWrapped.position() > 0) {
                        onReadable();
                    }
                    return;
                default:
                    // NEED_UNWRAP / NOT_HANDSHAKING: keep unwrapping buffered records
                    break;
            }
        } while (result.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW);

        // not enough buffered bytes for a whole record; wait for more from the peer
        readAndUnwrapHandshake();
    }

    /**
     * Appends {@code data} to {@link #bufferedWrapped}. If there is not enough room, the buffer
     * is first replaced by a larger one holding the same bytes. A plain {@code put} would throw
     * {@code BufferOverflowException}, which is not an {@link IOException} and would escape the
     * handshake listener without failing the connect.
     */
    private void appendWrapped(ByteBuffer data) {
        if (data.remaining() > bufferedWrapped.remaining()) {
            ByteBuffer grown = ByteBuffer.allocate(bufferedWrapped.position() + data.remaining());
            bufferedWrapped.flip();
            grown.put(bufferedWrapped);
            bufferedWrapped = grown;
        }
        bufferedWrapped.put(data);
    }

    /**
     * Issues one read through the base class's {@code read()} and passes the received bytes to
     * {@link #unwrapHandshake}. This bypasses this class's {@link #read()} override, which
     * would create an {@link SSLReadFuture}. Any failure fails the connect.
     */
    private void readAndUnwrapHandshake() {
        ReadFuture readFuture = super.read();
        readFuture.addListener(new Inline.Listener<ByteBuffer>() {

            @Override
            public void onComplete(ByteBuffer result) {
                try {
                    unwrapHandshake(result);
                } catch (IOException ioe) {
                    failConnect(ioe);
                }
            }

            @Override
            public void onError(Throwable throwable) {
                failConnect(throwable);
            }
        });
    }

    /** Runs every pending delegated engine task synchronously on the calling thread. */
    private static void runDelegatedTasks(SSLEngine engine) {
        Runnable task;
        while ((task = engine.getDelegatedTask()) != null) {
            task.run();
        }
    }

    /** Reads and decrypts application data into {@code readBuffer}. */
    @Override
    public ReadFuture read(ByteBuffer readBuffer) {
        return read(new SSLReadFuture(readBuffer));
    }

    /** Reads and decrypts application data into a buffer supplied by {@link ReadFuture}. */
    @Override
    public ReadFuture read() {
        return read(new SSLReadFuture());
    }

    /** Encrypts {@code data} into TLS records (lazily, on first write attempt) and writes them. */
    @Override
    public WriteFuture write(ByteBuffer[] data) {
        return write(new SSLWriteFuture(data));
    }

    /**
     * Returns the base read queue when the engine is not handshaking, otherwise the
     * handshake read queue.
     *
     * <p>NOTE(review): an engine that has not yet begun its handshake also reports
     * {@code NOT_HANDSHAKING}, so reads queued before {@link #finishConnect()} land on the base
     * queue. Confirm this is intended.
     */
    @Override
    protected Deque<ReadFuture> getReadQueue() {
        if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            return super.getReadQueue();
        } else {
            return handshakeReadQueue;
        }
    }

    /**
     * Returns the base write queue when the engine is not handshaking, otherwise the queue of
     * handshake records filled by {@link #wrapHandshake()}.
     */
    @Override
    protected Deque<WriteFuture> getWriteQueue() {
        if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            return super.getWriteQueue();
        } else {
            return handshakeWriteQueue;
        }
    }

    /**
     * Marks the engine's outbound side closed, then closes the underlying socket.
     *
     * <p>NOTE(review): nothing here wraps and sends the close_notify alert that
     * {@code closeOutbound()} makes pending, so the peer sees a bare TCP close.
     */
    @Override
    public void close() {
        engine.closeOutbound();
        super.close();
    }

    /** A read that pulls network bytes into {@link #bufferedWrapped} and unwraps them. */
    private final class SSLReadFuture extends ReadFuture {

        SSLReadFuture() {
            super();
        }

        SSLReadFuture(ByteBuffer buffer) {
            super(buffer);
        }

        /**
         * Reads available network bytes, then unwraps as many buffered records as fit into
         * the destination buffer, and completes this future with the plaintext.
         *
         * @return always {@code true}; this future is completed or failed on every call
         * @throws SSLException if the engine starts a renegotiation
         */
        @Override
        boolean doRead(SocketChannel channel) throws IOException {

            ByteBuffer out = getBuffer();

            int bytesRead = channel.read(bufferedWrapped);

            if (bytesRead < 0) {
                fail(new IOException("connection closed"));
                return true;
            }

            bufferedWrapped.flip();

            SSLEngineResult.Status status = SSLEngineResult.Status.OK;
            int produced = 0;
            try {
                while (out.remaining() > 0 && bufferedWrapped.remaining() > 0
                    && status == SSLEngineResult.Status.OK) {

                    SSLEngineResult result = engine.unwrap(bufferedWrapped, out);
                    status = result.getStatus();
                    produced += result.bytesProduced();

                    if (status == SSLEngineResult.Status.CLOSED) {
                        // The peer closed the TLS session. The engine may now want to wrap a
                        // reply, which is not a renegotiation.
                        break;
                    }

                    if (result.getHandshakeStatus()
                        != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
                        throw new SSLException("renegotiation not supported");
                    }
                }
            } finally {
                // back to write mode, keeping any partial record for the next read
                bufferedWrapped.compact();
            }

            if (status == SSLEngineResult.Status.CLOSED && produced == 0) {
                fail(new IOException("connection closed"));
                return true;
            }

            // NOTE(review): a partial record (BUFFER_UNDERFLOW) still completes this read,
            // possibly with zero bytes.
            out.flip();
            set(out);
            return true;
        }
    }

    /** A write whose plaintext is replaced by TLS records the first time it is written. */
    private final class SSLWriteFuture extends WriteFuture {

        // true once the pending plaintext has been replaced by its TLS records
        private boolean wrapped = false;

        SSLWriteFuture(ByteBuffer[] data) {
            super(data);
        }

        /**
         * Wraps all pending plaintext into packet-sized TLS records and makes those records
         * the pending data.
         *
         * @throws SSLException if the engine starts a renegotiation
         */
        private void wrap() throws IOException {

            ArrayList<ByteBuffer> records = new ArrayList<>();

            while (true) {
                // TODO(bgallagher) buffer pooling
                ByteBuffer out = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());

                SSLEngineResult result = engine.wrap(pending(), out);

                if (result.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
                    throw new SSLException("renegotiation not supported");
                }

                if (result.bytesProduced() == 0) {
                    break;
                }
                out.flip();
                records.add(out);
            }

            setPending(records.toArray(EMPTY_BUFFER_ARRAY));
        }

        @Override
        boolean doWrite() throws IOException {
            if (!wrapped) {
                wrap();
                wrapped = true;
            }

            return super.doWrite();
        }
    }
}