/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.server.protocol.ajp;

import static org.xnio.Bits.anyAreSet;
import static org.xnio.Bits.longBitMask;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

import io.undertow.UndertowMessages;
import io.undertow.conduits.ConduitListener;
import io.undertow.server.Connectors;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.ImmediatePooledByteBuffer;
import org.xnio.IoUtils;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.conduits.AbstractStreamSourceConduit;
import org.xnio.conduits.ConduitReadableByteChannel;
import org.xnio.conduits.StreamSourceConduit;

/**
 * Underlying AJP request channel.
 * <p>
 * Reads the request body from the wrapped conduit as a sequence of chunks, each preceded by a
 * {@value #HEADER_LENGTH} byte header (two magic bytes {@code 0x12 0x34}, a two byte packet length
 * and a two byte chunk length). When a chunk has been fully consumed and more data is expected, the
 * next chunk is requested through the {@link AjpServerResponseConduit} before reading continues.
 * <p>
 * Any {@link IOException} or {@link RuntimeException} raised by a public read operation closes the
 * exchange's connection before being rethrown.
 *
 * @author Stuart Douglas
 */
public class AjpServerRequestConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {

    /**
     * Pre-built packet that is handed to the response conduit every time another body chunk has to
     * be requested. It is shared, so callers must always work on a {@link ByteBuffer#duplicate()}.
     */
    private static final ByteBuffer READ_BODY_CHUNK;

    static {
        ByteBuffer readBody = ByteBuffer.allocateDirect(7);
        readBody.put((byte) 'A');
        readBody.put((byte) 'B');
        readBody.put((byte) 0);
        readBody.put((byte) 3);
        readBody.put((byte) 6);
        readBody.put((byte) 0x1F);
        readBody.put((byte) 0xFA);
        readBody.flip();
        READ_BODY_CHUNK = readBody;

    }

    /**
     * Size in bytes of the header that precedes every body chunk.
     */
    private static final int HEADER_LENGTH = 6;

    /**
     * There is a packet coming from apache.
     */
    private static final long STATE_READING = 1L << 63L;
    /**
     * There is no packet coming, as we need to send a GET_BODY_CHUNK message
     */
    private static final long STATE_SEND_REQUIRED = 1L << 62L;
    /**
     * read is done
     */
    private static final long STATE_FINISHED = 1L << 61L;

    /**
     * The remaining bits are used to store the remaining chunk size.
     */
    private static final long STATE_MASK = longBitMask(0, 60);


    private final HttpServerExchange exchange;

    private final AjpServerResponseConduit ajpResponseConduit;

    /**
     * byte buffer that is used to hold header data
     */
    private final ByteBuffer headerBuffer = ByteBuffer.allocateDirect(HEADER_LENGTH);

    /**
     * Listener notified once when the conduit moves to the finished state; may be {@code null}.
     */
    private final ConduitListener<? super AjpServerRequestConduit> finishListener;

    /**
     * The total amount of remaining data. If this is unknown it is -1.
     */
    private long remaining;

    /**
     * State flags, with the chunk remaining stored in the low bytes
     */
    private long state;

    /**
     * The total amount of data that has been read
     */
    private long totalRead;

    /**
     * Creates a new request conduit.
     *
     * @param delegate           the conduit the raw AJP data is read from
     * @param exchange           the exchange this request body belongs to
     * @param ajpResponseConduit the response conduit used to request further body chunks
     * @param size               the total body size; {@code null} if unknown, in which case the first
     *                           read has to request a chunk, and {@code 0} if there is no body at all
     * @param finishListener     listener invoked when the body has been completely read, may be {@code null}
     */
    public AjpServerRequestConduit(final StreamSourceConduit delegate, HttpServerExchange exchange, AjpServerResponseConduit ajpResponseConduit, Long size, ConduitListener<? super AjpServerRequestConduit> finishListener) {
        super(delegate);
        this.exchange = exchange;
        this.ajpResponseConduit = ajpResponseConduit;
        this.finishListener = finishListener;
        if (size == null) {
            state = STATE_SEND_REQUIRED;
            remaining = -1;
        } else if (size.longValue() == 0L) {
            state = STATE_FINISHED;
            remaining = 0;
        } else {
            state = STATE_READING;
            remaining = size;
        }
    }

    /**
     * Transfers body data into the given file channel by reading through this conduit, so chunk
     * framing is handled by {@link #read(ByteBuffer)}.
     */
    @Override
    public long transferTo(long position, long count, FileChannel target) throws IOException {
        try {
            return target.transferFrom(new ConduitReadableByteChannel(this), position, count);
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    /**
     * Transfers body data into the given sink channel by reading through this conduit, so chunk
     * framing is handled by {@link #read(ByteBuffer)}.
     */
    @Override
    public long transferTo(long count, ByteBuffer throughBuffer, StreamSinkChannel target) throws IOException {
        try {
            return IoUtils.transfer(new ConduitReadableByteChannel(this), count, throughBuffer, target);
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    /**
     * Terminates reads on the underlying conduit, unless the body has been fully read on a
     * persistent exchange, in which case the underlying conduit is left untouched.
     */
    @Override
    public void terminateReads() throws IOException {
        if(exchange.isPersistent() && anyAreSet(state, STATE_FINISHED)) {
            return;
        }
        super.terminateReads();
    }

    /**
     * Scattering read into {@code length} buffers starting at index {@code offset}.
     *
     * @return the number of bytes read; if nothing was read, the result of the first
     *         (non-positive) single-buffer read
     */
    @Override
    public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
        try {
            long total = 0;
            // length is a count of buffers, not an end index, so the upper bound is offset + length
            final int end = offset + length;
            for (int i = offset; i < end; ++i) {
                while (dsts[i].hasRemaining()) {
                    int r = read(dsts[i]);
                    if (r <= 0 && total > 0) {
                        return total;
                    } else if (r <= 0) {
                        return r;
                    } else {
                        total += r;
                    }
                }
            }
            return total;
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    /**
     * Reads body data into {@code dst}, requesting the next body chunk first if one is required.
     *
     * @return the number of bytes read, {@code 0} if no data is currently available (or the chunk
     *         request could not be handed off yet), or {@code -1} once the body is finished
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        try {
            long state = this.state;
            if (anyAreSet(state, STATE_FINISHED)) {
                return -1;
            } else if (anyAreSet(state, STATE_SEND_REQUIRED)) {
                state = this.state = (state & STATE_MASK) | STATE_READING;
                if (ajpResponseConduit.isWriteShutdown()) {
                    // the chunk request can no longer be sent, so no more data will ever arrive
                    markFinished();
                    return -1;
                }
                if (!ajpResponseConduit.doGetRequestBodyChunk(READ_BODY_CHUNK.duplicate(), this)) {
                    return 0;
                }
            }
            //we might have gone into state_reading above
            if (anyAreSet(state, STATE_READING)) {
                return doRead(dst, state);
            }
            assert STATE_FINISHED == state;
            return -1;
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    /**
     * Moves the conduit into the finished state and notifies the finish listener, if any.
     */
    private void markFinished() {
        this.state = STATE_FINISHED;
        if (finishListener != null) {
            finishListener.handleEvent(this);
        }
    }

    /**
     * Reads (the rest of) the current chunk header if necessary and then chunk data into {@code dst}.
     *
     * @param dst   the destination buffer
     * @param state the state value the caller is working with
     * @return the number of bytes read, {@code 0} if nothing is available yet, {@code -1} at the end of the body
     * @throws ClosedChannelException if the underlying conduit reaches end of stream before the body is complete
     * @throws IOException            on a bad magic number, an oversized entity or any underlying I/O error
     */
    private int doRead(final ByteBuffer dst, long state) throws IOException {
        ByteBuffer headerBuffer = this.headerBuffer;
        long headerRead = HEADER_LENGTH - headerBuffer.remaining();
        long remaining = this.remaining;
        if (remaining == 0) {
            markFinished();
            return -1;
        }
        long chunkRemaining;
        if (headerRead != HEADER_LENGTH) {
            int read = next.read(headerBuffer);
            if (read == -1) {
                markFinished();
                throw new ClosedChannelException();
            } else if (headerBuffer.hasRemaining()) {
                if(headerBuffer.remaining() <= 2) {
                    //mod_jk can send 12 34 00 00 rather than 12 34 00 02 00 00

                    byte b1 = headerBuffer.get(0); //0x12
                    byte b2 = headerBuffer.get(1); //0x34
                    if (b1 != 0x12 || b2 != 0x34) {
                        throw UndertowMessages.MESSAGES.wrongMagicNumber((b1 & 0xFF) << 8 | (b2 & 0xFF));
                    }
                    b1 = headerBuffer.get(2);//the length headers, two more than the string length header
                    b2 = headerBuffer.get(3);
                    int totalSize = ((b1 & 0xFF) << 8) | (b2 & 0xFF);
                    if(totalSize == 0) {
                        if(headerBuffer.remaining() < 2) {
                            // a fifth byte was read that is not part of this short packet; hand it back
                            byte[] data = new byte[1];
                            ByteBuffer bb = ByteBuffer.wrap(data);
                            bb.put(headerBuffer.get(4));
                            bb.flip();
                            Connectors.ungetRequestBytes(exchange, new ImmediatePooledByteBuffer(bb));
                        }
                        this.remaining = 0;
                        markFinished();
                        return -1;
                    }
                }

                return 0;
            } else {
                headerBuffer.flip();
                byte b1 = headerBuffer.get(); //0x12
                byte b2 = headerBuffer.get(); //0x34
                if (b1 != 0x12 || b2 != 0x34) {
                    throw UndertowMessages.MESSAGES.wrongMagicNumber((b1 & 0xFF) << 8 | (b2 & 0xFF));
                }
                b1 = headerBuffer.get();//the length headers, two more than the string length header
                b2 = headerBuffer.get();
                int totalSize = ((b1 & 0xFF) << 8) | (b2 & 0xFF);
                if(totalSize == 0) {
                    // the last two header bytes are not part of this packet; hand them back
                    byte[] data = new byte[2];
                    ByteBuffer bb = ByteBuffer.wrap(data);
                    bb.put(headerBuffer);
                    bb.flip();
                    Connectors.ungetRequestBytes(exchange, new ImmediatePooledByteBuffer(bb));
                    this.remaining = 0;
                    markFinished();
                    return -1;
                }

                b1 = headerBuffer.get();
                b2 = headerBuffer.get();
                chunkRemaining = ((b1 & 0xFF) << 8) | (b2 & 0xFF);
                if (chunkRemaining == 0) {
                    this.remaining = 0;
                    markFinished();
                    return -1;
                }
            }
        } else {
            // header already consumed by an earlier call, the chunk size lives in the state bits
            chunkRemaining = this.state & STATE_MASK;
        }

        final int limit = dst.limit();
        final int read;
        try {
            // never read past the end of the current chunk
            if (dst.remaining() > chunkRemaining) {
                dst.limit((int) (dst.position() + chunkRemaining));
            }
            read = next.read(dst);
            if (read == -1) {
                // end of stream in the middle of a chunk: the body is truncated. Treat it like
                // EOF while reading the header instead of folding -1 into the byte counters.
                markFinished();
                throw new ClosedChannelException();
            }
            chunkRemaining -= read;
            if (remaining != -1) {
                remaining -= read;
            }
            this.totalRead += read;
            if (remaining != 0) {
                if (chunkRemaining == 0) {
                    headerBuffer.clear();
                    this.state = STATE_SEND_REQUIRED;
                } else {
                    this.state = (state & ~STATE_MASK) | chunkRemaining;
                }
            }
        } finally {
            this.remaining = remaining;
            dst.limit(limit);
        }

        // checked outside of the finally block so that it can never mask an exception from the read
        final long maxEntitySize = exchange.getMaxEntitySize();
        if (maxEntitySize > 0) {
            if (totalRead > maxEntitySize) {
                //kill the connection, nothing else can be sent on it
                terminateReads();
                exchange.setPersistent(false);
                throw UndertowMessages.MESSAGES.requestEntityWasTooLarge(maxEntitySize);
            }
        }
        return read;
    }

    /**
     * Waits for the underlying conduit to become readable, but only while a packet is expected;
     * in any other state this returns immediately.
     */
    @Override
    public void awaitReadable() throws IOException {
        try {
            if (anyAreSet(state, STATE_READING)) {
                next.awaitReadable();
            }
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    /**
     * Timed variant of {@link #awaitReadable()}; it only waits while a packet is expected.
     */
    @Override
    public void awaitReadable(long time, TimeUnit timeUnit) throws IOException {
        try {
            if (anyAreSet(state, STATE_READING)) {
                next.awaitReadable(time, timeUnit);
            }
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    /**
     * Method that is called when an error occurs writing out read body chunk frames.
     * The connection is closed and, if reads are resumed, the read listener is woken up.
     *
     * @param e the error that occurred; it is currently not inspected
     */
    void setReadBodyChunkError(IOException e) {
        IoUtils.safeClose(exchange.getConnection());
        if(isReadResumed()) {
            wakeupReads();
        }
    }
}