/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.pgasync.netty;

import com.github.pgasync.PgProtocolStream;
import com.github.pgasync.message.Message;
import com.github.pgasync.message.backend.SSLHandshake;
import com.github.pgasync.message.frontend.SSLRequest;
import com.github.pgasync.message.frontend.StartupMessage;
import com.github.pgasync.message.frontend.Terminate;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;

/**
 * Netty messages stream to Postgres backend.
 *
 * @author Antti Laisi
 */
public class NettyProtocolStream extends PgProtocolStream {

    private final Bootstrap channelPipeline;
    private StartupMessage startupWith;
    private ChannelHandlerContext ctx;

    private final GenericFutureListener<Future<? super Object>> outboundErrorListener = written -> {
        if (!written.isSuccess()) {
            respondWithException(written.cause());
        }
    };

    public NettyProtocolStream(EventLoopGroup group, SocketAddress address, boolean useSsl, Charset encoding, Executor futuresExecutor) {
        super(address, useSsl, encoding, futuresExecutor);
        this.channelPipeline = new Bootstrap()
                .group(group)
                .channel(NioSocketChannel.class)
                .handler(newProtocolInitializer());
    }

    @Override
    public CompletableFuture<Message> connect(StartupMessage startup) {
        startupWith = startup;
        return offerRoundTrip(() -> channelPipeline.connect(address).addListener(outboundErrorListener), false)
                .thenApply(this::send)
                .thenCompose(Function.identity())
                .thenApply(message -> {
                    if (message == SSLHandshake.INSTANCE) {
                        return send(startup);
                    } else {
                        return CompletableFuture.completedFuture(message);
                    }
                })
                .thenCompose(Function.identity());
    }

    @Override
    public boolean isConnected() {
        return ctx.channel().isOpen();
    }

    @Override
    public CompletableFuture<Void> close() {
        CompletableFuture<Void> uponClose = new CompletableFuture<>();
        ctx.writeAndFlush(Terminate.INSTANCE).addListener(written -> {
            if (written.isSuccess()) {
                ctx.close().addListener(closed -> {
                    if (closed.isSuccess()) {
                        uponClose.completeAsync(() -> null, futuresExecutor);
                    } else {
                        Throwable th = closed.cause();
                        futuresExecutor.execute(() -> uponClose.completeExceptionally(th));
                    }
                });
            } else {
                Throwable th = written.cause();
                futuresExecutor.execute(() -> uponClose.completeExceptionally(th));
            }
        });
        return uponClose;
    }

    @Override
    protected void write(Message... messages) {
        for (Message message : messages) {
            ctx.write(message).addListener(outboundErrorListener);
        }
        ctx.flush();
    }

    private ChannelInitializer<Channel> newProtocolInitializer() {
        return new ChannelInitializer<>() {
            @Override
            protected void initChannel(Channel channel) {
                if (useSsl) {
                    channel.pipeline().addLast(newSslInitiator());
                }
                channel.pipeline().addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 1, 4, -4, 0, true));
                channel.pipeline().addLast(new NettyMessageDecoder(encoding));
                channel.pipeline().addLast(new NettyMessageEncoder(encoding));
                channel.pipeline().addLast(newProtocolHandler());
            }
        };
    }

    private ChannelHandler newSslInitiator() {
        return new ByteToMessageDecoder() {
            @Override
            protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
                if (in.readableBytes() >= 1) {
                    if ('S' == in.readByte()) { // SSL supported response
                        ctx.pipeline().remove(this);
                        ctx.pipeline().addFirst(
                                SslContextBuilder
                                        .forClient()
                                        .trustManager(InsecureTrustManagerFactory.INSTANCE)
                                        .build()
                                        .newHandler(ctx.alloc()));
                    } else {
                        ctx.fireExceptionCaught(new IllegalStateException("SSL required but not supported by Postgres"));
                    }
                }
            }
        };
    }

    private ChannelHandler newProtocolHandler() {
        return new ChannelInboundHandlerAdapter() {

            @Override
            public void channelActive(ChannelHandlerContext context) {
                NettyProtocolStream.this.ctx = context;
                if (useSsl) {
                    respondWithMessage(SSLRequest.INSTANCE);
                } else {
                    respondWithMessage(startupWith);
                }
            }

            @Override
            public void userEventTriggered(ChannelHandlerContext context, Object evt) {
                if (evt instanceof SslHandshakeCompletionEvent && ((SslHandshakeCompletionEvent) evt).isSuccess()) {
                    respondWithMessage(SSLHandshake.INSTANCE);
                }
            }

            @Override
            public void channelRead(ChannelHandlerContext context, Object message) {
                if (message instanceof Message) {
                    respondWithMessage((Message) message);
                }
            }

            @Override
            public void channelInactive(ChannelHandlerContext context) {
                exceptionCaught(context, new IOException("Channel state changed to inactive"));
            }

            @Override
            public void exceptionCaught(ChannelHandlerContext context, Throwable cause) {
                respondWithException(cause);
            }
        };
    }

}