/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.conduits;

import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreSet;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import java.util.zip.Deflater;

import io.undertow.server.Connectors;
import org.xnio.IoUtils;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.WriteReadyHandler;

import io.undertow.UndertowLogger;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.ConduitFactory;
import io.undertow.util.NewInstanceObjectPool;
import io.undertow.util.ObjectPool;
import io.undertow.util.Headers;
import io.undertow.util.PooledObject;
import io.undertow.util.SimpleObjectPool;

/**
 * Stream sink conduit that runs everything written to it through a {@link Deflater} before passing it
 * on to the next conduit.
 * <p>
 * Compressed output is accumulated in a pooled buffer ({@link #currentBuffer}); output that does not fit is
 * kept in an overflow buffer. The downstream conduit is created lazily from the supplied
 * {@link ConduitFactory} the first time buffered data has to be flushed. Until it exists, the
 * resumed/suspended state of writes is tracked locally and write-ready notifications are dispatched on the
 * connection's I/O thread.
 *
 * @author Stuart Douglas
 */
public class DeflatingStreamSinkConduit implements StreamSinkConduit {

    /**
     * The deflater used for compression. Set to {@code null} once the conduit's resources have been freed.
     */
    protected volatile Deflater deflater;

    /**
     * Pool handle the deflater was obtained from; closed when the conduit's resources are freed.
     */
    protected final PooledObject<Deflater> pooledObject;
    private final ConduitFactory<StreamSinkConduit> conduitFactory;
    private final HttpServerExchange exchange;

    /**
     * The downstream conduit. {@code null} until the first time buffered data needs to be flushed.
     */
    private StreamSinkConduit next;
    private WriteReadyHandler writeReadyHandler;


    /**
     * The streams buffer. This is freed when the next is shutdown
     */
    protected PooledByteBuffer currentBuffer;
    /**
     * there may have been some additional data that did not fit into the first buffer
     */
    private ByteBuffer additionalBuffer;

    /**
     * Bit set made up of the flag constants below.
     */
    private int state = 0;

    /** Set by {@link #terminateWrites()}: no more input will be accepted. */
    private static final int SHUTDOWN = 1;
    /** Set by {@link #flush()} once {@code terminateWrites()} has been called on the next conduit. */
    private static final int NEXT_SHUTDOWN = 1 << 1;
    /** The current buffer has been flipped and is being written to the next conduit. */
    private static final int FLUSHING_BUFFER = 1 << 2;
    /** Writes were resumed while the next conduit did not exist yet. */
    private static final int WRITES_RESUMED = 1 << 3;
    /** Set by {@link #truncateWrites()}. */
    private static final int CLOSED = 1 << 4;
    /** The trailer returned by {@link #getTrailer()} has already been requested and queued. */
    private static final int WRITTEN_TRAILER = 1 << 5;

    /**
     * Creates a conduit that uses a freshly created deflater.
     *
     * @param conduitFactory factory used to lazily create the next conduit
     * @param exchange       the exchange this conduit belongs to
     */
    public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange) {
        // NOTE(review): Deflater.DEFLATED is the compression *method* constant (value 8), so this selects
        // compression level 8. Deflater.DEFAULT_COMPRESSION may have been intended; left unchanged so the
        // compressed output stays the same.
        this(conduitFactory, exchange, Deflater.DEFLATED);
    }

    /**
     * Creates a conduit that uses a freshly created deflater with the given level.
     *
     * @param conduitFactory factory used to lazily create the next conduit
     * @param exchange       the exchange this conduit belongs to
     * @param deflateLevel   the compression level handed to the {@link Deflater} constructor
     */
    public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, int deflateLevel) {
        this(conduitFactory, exchange, newInstanceDeflaterPool(deflateLevel));
    }

    /**
     * Creates a conduit that takes its deflater from the given pool. An output buffer is allocated from the
     * connection's byte buffer pool, and the write-ready handler is initialised to one that notifies the
     * exchange's conduit sink channel.
     *
     * @param conduitFactory factory used to lazily create the next conduit
     * @param exchange       the exchange this conduit belongs to
     * @param deflaterPool   the pool to allocate the deflater from
     */
    public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, ObjectPool<Deflater> deflaterPool) {
        this.pooledObject = deflaterPool.allocate();
        this.deflater = pooledObject.getObject();
        this.currentBuffer = exchange.getConnection().getByteBufferPool().allocate();
        this.exchange = exchange;
        this.conduitFactory = conduitFactory;
        setWriteReadyHandler(new WriteReadyHandler.ChannelListenerHandler<>(Connectors.getConduitSinkChannel(exchange)));
    }

    /**
     * Creates a pool backed by {@link NewInstanceObjectPool} whose deflaters are constructed with the
     * {@code nowrap} option enabled.
     *
     * @param deflateLevel the compression level for created deflaters
     * @return the pool
     */
    public static ObjectPool<Deflater> newInstanceDeflaterPool(int deflateLevel) {
        return new NewInstanceObjectPool<Deflater>(() -> new Deflater(deflateLevel, true), Deflater::end);
    }

    /**
     * Creates a pool backed by {@link SimpleObjectPool} whose deflaters are constructed with the
     * {@code nowrap} option enabled. {@code Deflater::reset} and {@code Deflater::end} are handed to the pool
     * (presumably as its recycle and destroy callbacks).
     *
     * @param poolSize     the size passed to the pool
     * @param deflateLevel the compression level for created deflaters
     * @return the pool
     */
    public static ObjectPool<Deflater> simpleDeflaterPool(int poolSize, int deflateLevel) {
        return new SimpleObjectPool<Deflater>(poolSize, () -> new Deflater(deflateLevel, true), Deflater::reset, Deflater::end);
    }


    /**
     * Consumes the whole of {@code src} and feeds it to the deflater, provided any pending buffered output
     * could be flushed and the deflater is ready for more input.
     *
     * @param src the data to compress
     * @return the number of bytes consumed from {@code src}; {@code 0} if nothing could be accepted
     * @throws ClosedChannelException if writes were terminated or truncated, or the buffer was already freed
     * @throws IOException            if writing to the next conduit fails; the conduit's resources are freed
     */
    @Override
    public int write(final ByteBuffer src) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        try {
            if (!performFlushIfRequired()) {
                return 0;
            }
            if (src.remaining() == 0) {
                return 0;
            }
            //we may already have some input, if so compress it
            if (!deflater.needsInput()) {
                deflateData(false);
                if (!deflater.needsInput()) {
                    return 0;
                }
            }
            byte[] data = new byte[src.remaining()];
            src.get(data);
            preDeflate(data);
            deflater.setInput(data);
            // subtract the uncompressed size here; deflateData adds the size of the compressed output
            Connectors.updateResponseBytesSent(exchange, -data.length);
            deflateData(false);
            return data.length;
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Hook invoked with the uncompressed bytes immediately before they are handed to the deflater.
     * The default implementation does nothing.
     *
     * @param data the uncompressed data about to be deflated
     */
    protected void preDeflate(byte[] data) {

    }

    /**
     * Writes the non-empty buffers in order via {@link #write(ByteBuffer)}, stopping at the first buffer that
     * could not be accepted.
     *
     * @param srcs   the source buffers
     * @param offset index of the first buffer to write
     * @param length number of buffers to write
     * @return the total number of bytes consumed
     * @throws ClosedChannelException if writes were terminated or truncated, or the buffer was already freed
     * @throws IOException            if writing fails; the conduit's resources are freed
     */
    @Override
    public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        try {
            // accumulate as long: the sum over several buffers may exceed Integer.MAX_VALUE
            long total = 0;
            for (int i = offset; i < offset + length; ++i) {
                if (srcs[i].hasRemaining()) {
                    int ret = write(srcs[i]);
                    total += ret;
                    if (ret == 0) {
                        return total;
                    }
                }
            }
            return total;
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Delegates to {@link Conduits#writeFinalBasic(StreamSinkConduit, ByteBuffer)}.
     */
    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Conduits.writeFinalBasic(this, src);
    }

    /**
     * Delegates to {@link Conduits#writeFinalBasic(StreamSinkConduit, ByteBuffer[], int, int)}.
     */
    @Override
    public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
        return Conduits.writeFinalBasic(this, srcs, offset, length);
    }

    /**
     * Transfers file data into this conduit by wrapping it in a {@link ConduitWritableByteChannel}, so the
     * data passes through {@link #write(ByteBuffer)}.
     *
     * @return the value returned by {@link FileChannel#transferTo}, or {@code 0} if pending output could
     *         not be flushed first
     * @throws ClosedChannelException if writes were terminated or truncated
     */
    @Override
    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED)) {
            throw new ClosedChannelException();
        }
        if (!performFlushIfRequired()) {
            return 0;
        }
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }


    /**
     * Transfers data from a stream source channel into this conduit via {@link IoUtils#transfer}, using a
     * {@link ConduitWritableByteChannel} wrapper so the data passes through {@link #write(ByteBuffer)}.
     *
     * @return the value returned by {@code IoUtils.transfer}, or {@code 0} if pending output could not be
     *         flushed first
     * @throws ClosedChannelException if writes were terminated or truncated
     */
    @Override
    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED)) {
            throw new ClosedChannelException();
        }
        if (!performFlushIfRequired()) {
            return 0;
        }
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    /**
     * @return the worker of the exchange's connection
     */
    @Override
    public XnioWorker getWorker() {
        return exchange.getConnection().getWorker();
    }

    /**
     * Suspends writes on the next conduit, or clears the local resumed flag if it does not exist yet.
     */
    @Override
    public void suspendWrites() {
        if (next == null) {
            state = state & ~WRITES_RESUMED;
        } else {
            next.suspendWrites();
        }
    }


    /**
     * @return the resumed state of the next conduit, or the locally tracked flag if it does not exist yet
     */
    @Override
    public boolean isWriteResumed() {
        if (next == null) {
            return anyAreSet(state, WRITES_RESUMED);
        } else {
            return next.isWriteResumed();
        }
    }

    /**
     * Wakes up writes on the next conduit; if it does not exist yet this behaves like {@link #resumeWrites()}.
     */
    @Override
    public void wakeupWrites() {
        if (next == null) {
            resumeWrites();
        } else {
            next.wakeupWrites();
        }
    }

    /**
     * Resumes writes on the next conduit. If it does not exist yet, the resumed flag is recorded locally and
     * a write-ready notification is queued on the connection's I/O thread.
     */
    @Override
    public void resumeWrites() {
        if (next == null) {
            state |= WRITES_RESUMED;
            queueWriteListener();
        } else {
            next.resumeWrites();
        }
    }

    /**
     * Queues a task on the connection's I/O thread that invokes the write-ready handler, and re-queues itself
     * for as long as there is no next conduit and writes remain resumed.
     */
    private void queueWriteListener() {
        exchange.getConnection().getIoThread().execute(() -> {
            if (writeReadyHandler != null) {
                try {
                    writeReadyHandler.writeReady();
                } finally {
                    //if writes are still resumed queue up another one
                    if (next == null && isWriteResumed()) {
                        queueWriteListener();
                    }
                }
            }
        });
    }

    /**
     * Marks the conduit as shut down and tells the deflater (if still held) that no more input will follow.
     * The remaining compressed data is written out by {@link #flush()}.
     */
    @Override
    public void terminateWrites() throws IOException {
        if (deflater != null) {
            deflater.finish();
        }
        state |= SHUTDOWN;
    }

    /**
     * @return {@code true} once {@link #terminateWrites()} has been called
     */
    @Override
    public boolean isWriteShutdown() {
        return anyAreSet(state, SHUTDOWN);
    }

    /**
     * Waits on the next conduit; returns immediately if it does not exist yet.
     */
    @Override
    public void awaitWritable() throws IOException {
        if (next == null) {
            return;
        } else {
            next.awaitWritable();
        }
    }

    /**
     * Waits on the next conduit for at most the given time; returns immediately if it does not exist yet.
     */
    @Override
    public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
        if (next == null) {
            return;
        } else {
            next.awaitWritable(time, timeUnit);
        }
    }

    /**
     * @return the I/O thread of the exchange's connection
     */
    @Override
    public XnioIoThread getWriteThread() {
        return exchange.getConnection().getIoThread();
    }

    /**
     * Sets the handler that is notified when this conduit is ready for writing.
     */
    @Override
    public void setWriteReadyHandler(final WriteReadyHandler handler) {
        this.writeReadyHandler = handler;
    }

    /**
     * Flushes buffered compressed data to the next conduit, creating that conduit if necessary.
     * <p>
     * If writes have been terminated, this also drains the deflater, appends the trailer from
     * {@link #getTrailer()} once, and, when everything has been written, frees this conduit's resources and
     * terminates and flushes the next conduit. Otherwise the deflater is sync-flushed and the buffer is
     * pushed downstream.
     *
     * @return {@code true} if everything was flushed, {@code false} if more calls are required
     * @throws IOException if the next conduit fails; the conduit's resources are freed
     */
    @Override
    public boolean flush() throws IOException {
        if (currentBuffer == null) {
            // resources already freed: only the downstream conduit may still have something to flush
            if (anyAreSet(state, NEXT_SHUTDOWN)) {
                return next.flush();
            } else {
                return true;
            }
        }
        try {
            boolean nextCreated = false;
            try {
                if (anyAreSet(state, SHUTDOWN)) {
                    if (anyAreSet(state, NEXT_SHUTDOWN)) {
                        return next.flush();
                    } else {
                        if (!performFlushIfRequired()) {
                            return false;
                        }
                        //if the deflater has not been fully flushed we need to flush it
                        if (!deflater.finished()) {
                            deflateData(false);
                            //if could not fully flush
                            if (!deflater.finished()) {
                                return false;
                            }
                        }
                        final ByteBuffer buffer = currentBuffer.getBuffer();
                        if (allAreClear(state, WRITTEN_TRAILER)) {
                            state |= WRITTEN_TRAILER;
                            byte[] data = getTrailer();
                            if (data != null) {
                                Connectors.updateResponseBytesSent(exchange, data.length);
                                if (additionalBuffer != null) {
                                    // overflow data is pending: the trailer has to go after it
                                    final int pending = additionalBuffer.remaining();
                                    final byte[] newData = new byte[pending + data.length];
                                    additionalBuffer.get(newData, 0, pending);
                                    System.arraycopy(data, 0, newData, pending, data.length);
                                    this.additionalBuffer = ByteBuffer.wrap(newData);
                                } else if (anyAreSet(state, FLUSHING_BUFFER) && buffer.capacity() - buffer.remaining() >= data.length) {
                                    // buffer is in read mode with enough free space: switch back to write
                                    // mode, append the trailer, and return to read mode
                                    buffer.compact();
                                    buffer.put(data);
                                    buffer.flip();
                                } else if (data.length <= buffer.remaining() && !anyAreSet(state, FLUSHING_BUFFER)) {
                                    // buffer is still in write mode and the trailer fits
                                    buffer.put(data);
                                } else {
                                    additionalBuffer = ByteBuffer.wrap(data);
                                }
                            }
                        }

                        //ok the deflater is flushed, now we need to flush the buffer
                        if (!anyAreSet(state, FLUSHING_BUFFER)) {
                            buffer.flip();
                            state |= FLUSHING_BUFFER;
                            if (next == null) {
                                nextCreated = true;
                                this.next = createNextChannel();
                            }
                        }
                        if (performFlushIfRequired()) {
                            state |= NEXT_SHUTDOWN;
                            freeBuffer();
                            next.terminateWrites();
                            return next.flush();
                        } else {
                            return false;
                        }
                    }
                } else {
                    if (allAreClear(state, FLUSHING_BUFFER)) {
                        if (next == null) {
                            nextCreated = true;
                            this.next = createNextChannel();
                        }
                        deflateData(true);
                        if (allAreClear(state, FLUSHING_BUFFER)) {
                            //deflateData can cause this to be change
                            currentBuffer.getBuffer().flip();
                            this.state |= FLUSHING_BUFFER;
                        }
                    }
                    if (!performFlushIfRequired()) {
                        return false;
                    }
                    return next.flush();
                }
            } finally {
                if (nextCreated) {
                    // carry a locally recorded "resumed" state over to the newly created conduit
                    if (anyAreSet(state, WRITES_RESUMED) && !anyAreSet(state, NEXT_SHUTDOWN)) {
                        try {
                            next.resumeWrites();
                        } catch (Throwable e) {
                            UndertowLogger.REQUEST_LOGGER.debug("Failed to resume", e);
                        }
                    }
                }
            }
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Called by {@link #flush()} at most once, after the deflater has finished, to obtain bytes that should
     * be appended after the compressed data. The default implementation returns {@code null}.
     *
     * @return the trailer bytes, or {@code null} if there is no trailer
     */
    protected byte[] getTrailer() {
        return null;
    }

    /**
     * If we are in the flushing state then we flush to the underlying stream, otherwise just return true.
     * Once everything has been written the overflow buffer is dropped, the current buffer is cleared and the
     * flushing state is left.
     *
     * @return false if there is still more to flush
     */
    private boolean performFlushIfRequired() throws IOException {
        if (anyAreSet(state, FLUSHING_BUFFER)) {
            final ByteBuffer[] bufs = new ByteBuffer[additionalBuffer == null ? 1 : 2];
            long totalLength = 0;
            bufs[0] = currentBuffer.getBuffer();
            totalLength += bufs[0].remaining();
            if (additionalBuffer != null) {
                bufs[1] = additionalBuffer;
                totalLength += bufs[1].remaining();
            }
            if (totalLength > 0) {
                long total = 0;
                long res = 0;
                do {
                    res = next.write(bufs, 0, bufs.length);
                    total += res;
                    if (res == 0) {
                        // downstream cannot take more right now; stay in the flushing state
                        return false;
                    }
                } while (total < totalLength);
            }
            additionalBuffer = null;
            currentBuffer.getBuffer().clear();
            state = state & ~FLUSHING_BUFFER;
        }
        return true;
    }


    /**
     * Creates the next conduit from the factory, adjusting the {@code Content-Length} response header first:
     * if the deflater has finished and the trailer has been queued, the buffered bytes are the complete
     * response, so their size is set as the content length (unless a {@code Transfer-Encoding} header is
     * present); otherwise any {@code Content-Length} header is removed.
     *
     * @return the newly created conduit
     */
    private StreamSinkConduit createNextChannel() {
        if (deflater.finished() && allAreSet(state, WRITTEN_TRAILER)) {
            //the deflater was fully flushed before we created the channel. This means that what is in the buffer is
            //all there is
            int remaining = currentBuffer.getBuffer().remaining();
            if (additionalBuffer != null) {
                remaining += additionalBuffer.remaining();
            }
            if (!exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(remaining));
            }
        } else {
            exchange.getResponseHeaders().remove(Headers.CONTENT_LENGTH);
        }
        return conduitFactory.create();
    }

    /**
     * Runs the current data through the deflater. As much as possible this will be buffered in the current output
     * stream. When the output buffer fills up it is flushed to the next conduit (which is created if needed);
     * if that flush cannot complete, the method returns early with the conduit left in the flushing state.
     *
     * @param force if {@code true} the deflater is invoked with {@link Deflater#SYNC_FLUSH} until it produces
     *              no more output, even if it reports that it needs input
     * @throws IOException if writing to the next conduit fails
     */
    private void deflateData(boolean force) throws IOException {
        //we don't need to flush here, as this should have been called already by the time we get to
        //this point
        boolean nextCreated = false;
        // temporary array-backed buffer that receives the deflater's output
        try (PooledByteBuffer arrayPooled = this.exchange.getConnection().getByteBufferPool().getArrayBackedPool().allocate()) {
            PooledByteBuffer pooled = this.currentBuffer;
            final ByteBuffer outputBuffer = pooled.getBuffer();

            final boolean shutdown = anyAreSet(state, SHUTDOWN);
            ByteBuffer buf = arrayPooled.getBuffer();
            while (force || !deflater.needsInput() || (shutdown && !deflater.finished())) {
                int count = deflater.deflate(buf.array(), buf.arrayOffset(), buf.remaining(), force ? Deflater.SYNC_FLUSH: Deflater.NO_FLUSH);
                Connectors.updateResponseBytesSent(exchange, count);
                if (count != 0) {
                    int remaining = outputBuffer.remaining();
                    if (remaining > count) {
                        outputBuffer.put(buf.array(), buf.arrayOffset(), count);
                    } else {
                        if (remaining == count) {
                            outputBuffer.put(buf.array(), buf.arrayOffset(), count);
                        } else {
                            // fill the output buffer and keep the rest in the overflow buffer
                            outputBuffer.put(buf.array(), buf.arrayOffset(), remaining);
                            additionalBuffer = ByteBuffer.allocate(count - remaining);
                            additionalBuffer.put(buf.array(), buf.arrayOffset() + remaining, count - remaining);
                            additionalBuffer.flip();
                        }
                        // the output buffer is full: switch to the flushing state and push it downstream
                        outputBuffer.flip();
                        this.state |= FLUSHING_BUFFER;
                        if (next == null) {
                            nextCreated = true;
                            this.next = createNextChannel();
                        }
                        if (!performFlushIfRequired()) {
                            return;
                        }
                    }
                } else {
                    // nothing more was produced, so a forced flush is complete
                    force = false;
                }
            }
        } finally {
            if (nextCreated) {
                // carry a locally recorded "resumed" state over to the newly created conduit
                if (anyAreSet(state, WRITES_RESUMED)) {
                    next.resumeWrites();
                }
            }
        }
    }


    /**
     * Frees this conduit's resources, marks it closed and truncates writes on the next conduit if one has
     * been created.
     */
    @Override
    public void truncateWrites() throws IOException {
        freeBuffer();
        state |= CLOSED;
        // the next conduit is created lazily, so it may not exist if nothing was ever flushed
        if (next != null) {
            next.truncateWrites();
        }
    }

    /**
     * Releases the pooled output buffer and the pooled deflater, if still held, and leaves the flushing state.
     * Safe to call more than once.
     */
    private void freeBuffer() {
        if (currentBuffer != null) {
            currentBuffer.close();
            currentBuffer = null;
            state = state & ~FLUSHING_BUFFER;
        }
        if (deflater != null) {
            pooledObject.close();
            deflater = null;
        }
    }
}