/*
 * Copyright (c) 2015 Twitter, Inc. All rights reserved.
 * Licensed under the Apache License v2.0
 * http://www.apache.org/licenses/LICENSE-2.0
 */

package com.twitter.whiskey.nio;

import com.twitter.whiskey.futures.CompletableFuture;
import com.twitter.whiskey.futures.Listener;
import com.twitter.whiskey.futures.ReactiveFuture;
import com.twitter.whiskey.net.Protocol;
import com.twitter.whiskey.util.Origin;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.UnresolvedAddressException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.concurrent.TimeUnit;

/**
 * An asynchronous TCP socket interface.
 *
 * <p>Connection setup and the servicing of reads and writes are scheduled on the supplied
 * {@link RunLoop}. Reads and writes are queued as futures and processed in FIFO order: the
 * {@link #onReadable()} / {@link #onWriteable()} callbacks only ever operate on the head of
 * their respective queue, and the socket re-registers its interest set after every step.
 *
 * <p>NOTE(review): the queues, {@code closed} flag and futures are not synchronized; this class
 * appears to assume its callbacks run on the run loop's thread — confirm before calling
 * {@link #close()} or a future's {@code cancel} from other threads.
 *
 * @author Michael Schore
 * @author Bill Gallagher
 */
public class Socket extends Selectable {

    private final Origin origin;
    private final RunLoop runLoop;

    /** Set once the connect attempt failed or the socket was closed; never reset. */
    private boolean closed = false;

    private SocketChannel channel;
    private SelectionKey key;

    private ConnectFuture connectFuture;
    /** Created by {@link #finishConnect()}; {@code null} until the connection is established. */
    private CloseFuture closeFuture;
    private final Deque<ReadFuture> readQueue = new ArrayDeque<>(1);
    private final Deque<WriteFuture> writeQueue = new ArrayDeque<>(32);

    public Socket(Origin origin, RunLoop runLoop) {
        this.origin = origin;
        this.runLoop = runLoop;
    }

    /**
     * Starts a non-blocking connection to the origin on the run loop.
     *
     * <p>If the channel connects immediately (which a non-blocking channel may do), the
     * connection is finished right away; otherwise the socket is registered for
     * {@link SelectionKey#OP_CONNECT} and completed from {@link #onConnect()}.
     *
     * @return a future that is set to the origin once connected, or failed with the
     *     {@link IOException} / {@link UnresolvedAddressException} that prevented the connect
     */
    public ConnectFuture connect() {
        connectFuture = new ConnectFuture();

        runLoop.execute(new Runnable() {
            public void run() {
                try {
                    channel = SocketChannel.open();
                    channel.configureBlocking(false);
                    boolean connected = channel.connect(
                            new InetSocketAddress(origin.getHost(), origin.getPort()));
                    if (connected) {
                        // The channel will never become OP_CONNECT-ready in this case,
                        // so onConnect() would never run: finish the connection here.
                        finishConnect();
                    } else {
                        reregister();
                    }
                } catch (IOException | UnresolvedAddressException e) {
                    failConnect(e);
                    closed = true;
                    // Don't leak the descriptor of a channel that never connected.
                    closeChannelQuietly();
                }
            }
        });

        return connectFuture;
    }

    /**
     * Adds a listener to this socket's close future.
     *
     * <p>The close future is only created once the connection has been established, so this
     * must not be called before the {@link ConnectFuture} returned by {@link #connect()} has
     * succeeded; earlier calls dereference a {@code null} close future.
     */
    public void addCloseListener(Listener<Void> listener) {
        closeFuture.addListener(listener);
    }

    /** Queues a read into a newly allocated, default-sized buffer. */
    public ReadFuture read() {
        return read(new ReadFuture());
    }

    /**
     * Queues {@code readFuture} on the run loop and enables read interest if it is the only
     * pending read.
     *
     * @return {@code readFuture}, for chaining
     */
    public ReadFuture read(final ReadFuture readFuture) {
        runLoop.execute(new Runnable() {
            public void run() {
                Deque<ReadFuture> queue = getReadQueue();
                queue.add(readFuture);

                // Only the first queued read needs to (re)enable read interest; later reads are
                // serviced as the head completes. Never re-register a closed channel.
                if (!closed && channel != null && queue.size() == 1) {
                    reregister();
                }
            }
        });

        return readFuture;
    }

    /** Queues a read into the caller-supplied buffer. */
    public ReadFuture read(ByteBuffer readBuffer) {
        return read(new ReadFuture(readBuffer));
    }

    /** Queues a read. TODO: {@code timeout} and {@code timeoutUnit} are currently ignored. */
    public ReadFuture read(int timeout, TimeUnit timeoutUnit) {
        return read();
    }

    /** Queues a write of a single buffer. */
    public WriteFuture write(ByteBuffer data) {
        return write(new ByteBuffer[]{data});
    }

    /** Queues a gathering write of {@code data}, written in array order. */
    public WriteFuture write(ByteBuffer[] data) {
        return write(new WriteFuture(data));
    }

    /** Queues a write. TODO: {@code timeout} and {@code timeoutUnit} are currently ignored. */
    public WriteFuture write(ByteBuffer data, int timeout, TimeUnit timeoutUnit) {
        return write(new ByteBuffer[]{data});
    }

    /** Queues a write. TODO: {@code timeout} and {@code timeoutUnit} are currently ignored. */
    public WriteFuture write(ByteBuffer[] data, int timeout, TimeUnit timeoutUnit) {
        return write(data);
    }

    /**
     * Queues {@code writeFuture} on the run loop and enables write interest if the socket is
     * connected and this is the only pending write.
     *
     * @return {@code writeFuture}, for chaining
     */
    public WriteFuture write(final WriteFuture writeFuture) {
        runLoop.execute(new Runnable() {
            public void run() {
                Deque<WriteFuture> queue = getWriteQueue();
                queue.add(writeFuture);

                // Pending writes are picked up by finishConnect()'s reregister() otherwise.
                if (isConnected() && queue.size() == 1) {
                    reregister();
                }
            }
        });

        return writeFuture;
    }

    protected Deque<ReadFuture> getReadQueue() {
        return readQueue;
    }

    protected Deque<WriteFuture> getWriteQueue() {
        return writeQueue;
    }

    /**
     * Completes a pending connection once the channel reports connect readiness.
     *
     * <p>A non-blocking {@link SocketChannel#finishConnect()} may return {@code false} while the
     * connection is still pending, in which case the socket simply keeps waiting.
     */
    @Override
    public void onConnect() {
        try {
            if (!channel.finishConnect()) {
                reregister();
                return;
            }
            finishConnect();
        } catch (IOException e) {
            failConnect(e);
            closed = true;
            closeChannelQuietly();
        }
    }

    /** Creates the close future, completes the connect future and registers queued I/O. */
    void finishConnect() throws IOException {
        closeFuture = new CloseFuture();
        connectFuture.set(origin);
        reregister();
    }

    /** Fails the connect future unless it has already completed. */
    void failConnect(Throwable thr) {
        if (!connectFuture.isDone()) {
            connectFuture.fail(thr);
        }
    }

    /**
     * Performs one read into the head of the read queue, dequeuing it once it completes.
     * An {@link IOException} from the channel closes the socket with that error.
     */
    @Override
    public void onReadable() {

        if (closed) {
            return;
        }

        Deque<ReadFuture> readQueue = getReadQueue();

        if (readQueue.isEmpty()) {
            reregister();
            return;
        }

        ReadFuture currentRead = readQueue.peek();
        assert (!currentRead.isDone());

        boolean complete;
        try {
            complete = currentRead.doRead(channel);
        } catch (IOException e) {
            close(e);
            return;
        }

        if (complete) {
            readQueue.poll();
        }

        reregister();
    }

    /**
     * Performs one gathering write for the head of the write queue, dequeuing it once all of
     * its data has been written. An {@link IOException} closes the socket with that error.
     */
    @Override
    public void onWriteable() {

        if (closed) {
            return;
        }

        Deque<WriteFuture> writeQueue = getWriteQueue();

        if (writeQueue.isEmpty()) {
            reregister();
            return;
        }

        WriteFuture currentWrite = writeQueue.peek();
        assert (!currentWrite.isDone());

        boolean complete;
        try {
            complete = currentWrite.doWrite();
        } catch (IOException e) {
            close(e);
            return;
        }

        if (complete) {
            writeQueue.poll();
        }

        reregister();
    }

    /** Re-registers this socket with the run loop using the current interest set. */
    void reregister() {
        runLoop.register(interestSet(), this);
    }

    @Override
    public SocketChannel getChannel() {
        return channel;
    }

    @Override
    public void setSelectionKey(SelectionKey key) {
        this.key = key;
    }

    /** Returns whether the socket is open and its channel has finished connecting. */
    public boolean isConnected() {
        return !closed && channel != null && channel.isConnected();
    }

    /**
     * Computes the selection interest: only {@code OP_CONNECT} while a connection is pending,
     * otherwise {@code OP_READ} / {@code OP_WRITE} for each non-empty queue.
     */
    private int interestSet() {
        if (channel.isConnectionPending()) return SelectionKey.OP_CONNECT;

        int ops = 0;
        if (!getReadQueue().isEmpty()) ops |= SelectionKey.OP_READ;
        if (!getWriteQueue().isEmpty()) ops |= SelectionKey.OP_WRITE;
        return ops;
    }

    /** Marks the socket closed, drops the selection key and fails the close future with {@code e}. */
    @Override
    public void onClose(Throwable e) {

        if (closed) return;
        closed = true;
        key = null;
        if (closeFuture != null) closeFuture.fail(e);
    }

    boolean isSecure() {
        return false;
    }

    /** Closes the socket because of an I/O error and fails the close future with {@code e}. */
    private void close(Throwable e) {

        if (closed) return;
        closed = true;
        if (key != null) key.cancel();
        // Cancelling the key alone does not release the underlying descriptor.
        closeChannelQuietly();
        if (closeFuture != null) closeFuture.fail(e);
    }

    /** Closes the socket normally and completes the close future successfully. Idempotent. */
    public void close() {

        if (closed) return;
        closed = true;
        if (key != null) key.cancel();

        closeChannelQuietly();

        // The close future only exists once a connection has been established.
        if (closeFuture != null) closeFuture.set(null);
    }

    /** Closes the channel if one was opened, ignoring any error from doing so. */
    private void closeChannelQuietly() {
        if (channel == null) return;
        try {
            channel.close();
        } catch (IOException ignored) {
            // Best effort: the socket is already considered closed at this point.
        }
    }

    public Protocol getProtocol() {
        return Protocol.SPDY_3_1;
    }

    // TODO: split internal futures into public interface and package-private implementation
    public class ConnectFuture extends CompletableFuture<Origin> {
    }

    public class CloseFuture extends CompletableFuture<Void> {
    }

    /** A single queued read; completes with its buffer flipped for reading. */
    public class ReadFuture extends CompletableFuture<ByteBuffer> {

        private static final int DEFAULT_BUFFER_SIZE = 18 * 1024;

        private final ByteBuffer buffer;

        ReadFuture() {
            this(ByteBuffer.allocate(DEFAULT_BUFFER_SIZE));
        }

        ReadFuture(ByteBuffer buffer) {
            this.buffer = buffer;
        }

        /**
         * Reads once from {@code channel} into this future's buffer.
         *
         * @return {@code true} if the future was completed (data read, end of stream, or no
         *     buffer space), {@code false} if nothing was available yet and the read should stay
         *     queued
         */
        boolean doRead(SocketChannel channel) throws IOException {

            ByteBuffer buffer = getBuffer();

            int bytesRead = channel.read(buffer);

            if (bytesRead > 0) {
                buffer.flip();
                set(buffer);
                return true;
            }

            if (bytesRead < 0) {
                // End of stream: the peer closed its side of the connection.
                fail(new IOException("connection closed"));
                return true;
            }

            if (buffer.hasRemaining()) {
                // Zero bytes on a readable event (spurious readiness): keep waiting.
                return false;
            }

            // No room to read into; staying queued would spin on every readable event.
            fail(new IOException("read buffer has no remaining capacity"));
            return true;
        }

        /**
         * Cancels this read only while it is still queued, removing it from the read queue.
         * NOTE(review): touches the read queue without synchronization — presumably intended to
         * be called from the run loop thread.
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return getReadQueue().contains(this) && super.cancel(mayInterruptIfRunning) && getReadQueue().remove(this);
        }

        public ByteBuffer getBuffer() {
            return buffer;
        }
    }

    /**
     * A queued gathering write. Each partial write is reported through {@code provide}, and the
     * future is finished with the total byte count once every buffer has been drained.
     */
    public class WriteFuture extends ReactiveFuture<Long, Long> {
        private ByteBuffer[] data;
        ArrayList<Long> bytesWritten;
        Long totalBytesWritten;

        WriteFuture(ByteBuffer[] data) {
            this.data = data;
            bytesWritten = new ArrayList<>();
            totalBytesWritten = (long) 0;
        }

        ByteBuffer[] pending() throws IOException {
            return data;
        }

        public void setPending(ByteBuffer[] pending) {
            this.data = pending;
        }

        /**
         * Writes as much of the pending buffers as the channel accepts.
         *
         * @return {@code true} once no buffer has bytes remaining
         */
        boolean doWrite() throws IOException {
            long written = channel.write(data);

            // Check every buffer: judging only by the last one would treat a partial write as
            // complete when the trailing buffer is empty.
            boolean writeComplete = true;
            for (ByteBuffer buffer : data) {
                if (buffer.hasRemaining()) {
                    writeComplete = false;
                    break;
                }
            }

            provide(written);
            if (writeComplete) finish();
            return writeComplete;
        }

        @Override
        protected void accumulate(Long element) {
            bytesWritten.add(element);
            totalBytesWritten += element;
        }

        @Override
        protected Iterable<Long> drain() {
            return bytesWritten;
        }

        @Override
        protected boolean complete() {
            return set(totalBytesWritten);
        }

        /**
         * Cancels this write only while it is still queued, removing it from the write queue.
         * NOTE(review): touches the write queue without synchronization — presumably intended
         * to be called from the run loop thread.
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return getWriteQueue().contains(this) && super.cancel(mayInterruptIfRunning) && getWriteQueue().remove(this);
        }

    }
}