/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.dubbo.rpc.protocol.thrift;

import com.alibaba.dubbo.common.Constants;
import com.alibaba.dubbo.common.URL;
import com.alibaba.dubbo.remoting.Channel;
import com.alibaba.dubbo.remoting.buffer.ChannelBuffer;
import com.alibaba.dubbo.remoting.buffer.ChannelBuffers;
import com.alibaba.dubbo.remoting.exchange.Request;
import com.alibaba.dubbo.remoting.exchange.Response;
import com.alibaba.dubbo.remoting.exchange.support.DefaultFuture;
import com.alibaba.dubbo.rpc.RpcException;
import com.alibaba.dubbo.rpc.RpcInvocation;
import com.alibaba.dubbo.rpc.RpcResult;
import com.alibaba.dubbo.rpc.gen.thrift.Demo;
import com.alibaba.dubbo.rpc.protocol.thrift.io.RandomAccessByteArrayOutputStream;

import org.apache.thrift.TApplicationException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TTransport;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

import java.io.ByteArrayInputStream;

@Ignore
/**
 * Tests for {@link ThriftCodec}: encoding/decoding of Dubbo requests and responses that carry
 * the Dubbo thrift header (magic, message length, header length, version, service name,
 * request id) in front of a plain thrift message.
 */
public class ThriftCodecTest {

    private final ThriftCodec codec = new ThriftCodec();
    private final Channel channel = new MockedChannel(URL.valueOf("thrift://127.0.0.1"));

    /**
     * Prefixes {@code content} with its 4-byte frame length, as written by
     * {@link TFramedTransport#encodeFrameSize(int, byte[])}.
     *
     * @param content the unframed message bytes
     * @return a new array of {@code 4 + content.length} bytes
     */
    static byte[] encodeFrame(byte[] content) {
        byte[] result = new byte[4 + content.length];
        TFramedTransport.encodeFrameSize(content.length, result);
        System.arraycopy(content, 0, result, 4, content.length);
        return result;
    }

    @Test
    public void testEncodeRequest() throws Exception {
        Request request = createRequest();

        ChannelBuffer output = ChannelBuffers.dynamicBuffer(1024);
        codec.encode(channel, output, request);
        byte[] bytes = readableBytesOf(output);

        ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
        TTransport transport = new TIOStreamTransport(bis);
        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        // skip the 4-byte frame length prefix
        transport.read(new byte[4], 0, 4);
        readAndVerifyHeader(bis, protocol, bytes.length, request.getId());

        TMessage message = protocol.readMessageBegin();
        Demo.echoString_args args = new Demo.echoString_args();
        args.read(protocol);
        protocol.readMessageEnd();

        Assert.assertEquals("echoString", message.name);
        Assert.assertEquals(TMessageType.CALL, message.type);
        Assert.assertEquals("Hello, World!", args.getArg());
    }

    @Test
    public void testDecodeReplyResponse() throws Exception {
        URL url = URL.valueOf(ThriftProtocol.NAME + "://127.0.0.1:40880/" + Demo.Iface.class.getName());
        Channel serviceChannel = new MockedChannel(url);
        Request request = createRequest();

        // Constructed only for its side effects (presumably registering the in-flight request);
        // the instance itself is not needed.
        new DefaultFuture(serviceChannel, request, 10);

        RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream(128);
        TBinaryProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(bos));

        int headerLength = writeHeader(bos, protocol, Demo.Iface.class.getName(), request.getId());

        TMessage message = new TMessage("echoString", TMessageType.REPLY, ThriftCodec.getSeqId());
        Demo.echoString_result methodResult = new Demo.echoString_result();
        methodResult.success = "Hello, World!";

        protocol.writeMessageBegin(message);
        methodResult.write(protocol);
        protocol.writeMessageEnd();
        patchLengths(bos, protocol, headerLength);

        // frame the message the same way the other decode tests do (real frame length, not zeros)
        ChannelBuffer input = ChannelBuffers.wrappedBuffer(encodeFrame(bos.toByteArray()));
        Object obj = codec.decode((Channel) null, input);

        Assert.assertNotNull(obj);
        Assert.assertTrue(obj instanceof Response);

        Response response = (Response) obj;
        Assert.assertEquals(request.getId(), response.getId());
        Assert.assertTrue(response.getResult() instanceof RpcResult);

        RpcResult result = (RpcResult) response.getResult();
        Assert.assertTrue(result.getResult() instanceof String);
        Assert.assertEquals(methodResult.success, result.getResult());
    }

    @Test
    public void testDecodeExceptionResponse() throws Exception {
        URL url = URL.valueOf(ThriftProtocol.NAME + "://127.0.0.1:40880/" + Demo.class.getName());
        Channel serviceChannel = new MockedChannel(url);
        Request request = createRequest();

        // Constructed only for its side effects (presumably registering the in-flight request);
        // the instance itself is not needed.
        new DefaultFuture(serviceChannel, request, 10);

        RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream(128);
        TBinaryProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(bos));

        int headerLength = writeHeader(bos, protocol, Demo.class.getName(), request.getId());

        TMessage message = new TMessage("echoString", TMessageType.EXCEPTION, ThriftCodec.getSeqId());
        TApplicationException exception = new TApplicationException();

        protocol.writeMessageBegin(message);
        exception.write(protocol);
        protocol.writeMessageEnd();
        patchLengths(bos, protocol, headerLength);

        ChannelBuffer input = ChannelBuffers.wrappedBuffer(encodeFrame(bos.toByteArray()));
        Object obj = codec.decode((Channel) null, input);

        Assert.assertNotNull(obj);
        Assert.assertTrue(obj instanceof Response);

        Response response = (Response) obj;
        Assert.assertTrue(response.getResult() instanceof RpcResult);

        RpcResult result = (RpcResult) response.getResult();
        Assert.assertTrue(result.hasException());
        Assert.assertTrue(result.getException() instanceof RpcException);
    }

    @Test
    public void testEncodeReplyResponse() throws Exception {
        URL url = URL.valueOf(ThriftProtocol.NAME + "://127.0.0.1:40880/" + Demo.Iface.class.getName());
        Channel serviceChannel = new MockedChannel(url);
        Request request = createRequest();

        RpcResult rpcResult = new RpcResult();
        rpcResult.setResult("Hello, World!");

        Response response = new Response();
        response.setResult(rpcResult);
        response.setId(request.getId());

        // Use put (not putIfAbsent) so this test's RequestData always replaces one left behind by
        // another test for the same request id (every test uses id 1); a leftover entry could
        // carry a different seq id and break the seqid assertion below.
        ThriftCodec.RequestData rd = ThriftCodec.RequestData.create(
                ThriftCodec.getSeqId(), Demo.Iface.class.getName(), "echoString");
        ThriftCodec.cachedRequest.put(request.getId(), rd);

        ChannelBuffer output = ChannelBuffers.dynamicBuffer(1024);
        codec.encode(serviceChannel, output, response);
        byte[] bytes = readableBytesOf(output);

        ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
        TTransport transport = new TIOStreamTransport(bis);
        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        // skip the 4-byte frame length prefix
        transport.read(new byte[4], 0, 4);
        readAndVerifyHeader(bis, protocol, bytes.length, request.getId());

        TMessage message = protocol.readMessageBegin();
        Assert.assertEquals("echoString", message.name);
        Assert.assertEquals(TMessageType.REPLY, message.type);
        Assert.assertEquals(ThriftCodec.getSeqId(), message.seqid);

        Demo.echoString_result result = new Demo.echoString_result();
        result.read(protocol);
        protocol.readMessageEnd();

        Assert.assertEquals(rpcResult.getValue(), result.getSuccess());
    }

    @Test
    public void testEncodeExceptionResponse() throws Exception {
        URL url = URL.valueOf(ThriftProtocol.NAME + "://127.0.0.1:40880/" + Demo.Iface.class.getName());
        Channel serviceChannel = new MockedChannel(url);
        Request request = createRequest();

        RpcResult rpcResult = new RpcResult();
        String exceptionMessage = "failed";
        rpcResult.setException(new RuntimeException(exceptionMessage));

        Response response = new Response();
        response.setResult(rpcResult);
        response.setId(request.getId());

        ThriftCodec.RequestData rd = ThriftCodec.RequestData.create(
                ThriftCodec.getSeqId(), Demo.Iface.class.getName(), "echoString");
        ThriftCodec.cachedRequest.put(request.getId(), rd);

        ChannelBuffer output = ChannelBuffers.dynamicBuffer(1024);
        codec.encode(serviceChannel, output, response);
        byte[] bytes = readableBytesOf(output);

        ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
        TTransport transport = new TIOStreamTransport(bis);
        TBinaryProtocol protocol = new TBinaryProtocol(transport);

        // skip the 4-byte frame length prefix
        transport.read(new byte[4], 0, 4);
        readAndVerifyHeader(bis, protocol, bytes.length, request.getId());

        TMessage message = protocol.readMessageBegin();
        Assert.assertEquals("echoString", message.name);
        Assert.assertEquals(TMessageType.EXCEPTION, message.type);
        Assert.assertEquals(ThriftCodec.getSeqId(), message.seqid);

        TApplicationException exception = TApplicationException.read(protocol);
        protocol.readMessageEnd();

        Assert.assertEquals(exceptionMessage, exception.getMessage());
    }

    @Test
    public void testDecodeRequest() throws Exception {
        Request request = createRequest();

        RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream(1024);
        TBinaryProtocol protocol = new TBinaryProtocol(new TIOStreamTransport(bos));

        String serviceName = ((RpcInvocation) request.getData()).getAttachment(Constants.INTERFACE_KEY);
        int headerLength = writeHeader(bos, protocol, serviceName, request.getId());

        Demo.echoString_args args = new Demo.echoString_args();
        args.setArg("Hell, World!");

        TMessage message = new TMessage("echoString", TMessageType.CALL, ThriftCodec.getSeqId());
        protocol.writeMessageBegin(message);
        args.write(protocol);
        protocol.writeMessageEnd();
        patchLengths(bos, protocol, headerLength);

        Object obj = codec.decode((Channel) null, ChannelBuffers.wrappedBuffer(
                encodeFrame(bos.toByteArray())));

        Assert.assertTrue(obj instanceof Request);

        obj = ((Request) obj).getData();
        Assert.assertTrue(obj instanceof RpcInvocation);

        RpcInvocation invocation = (RpcInvocation) obj;
        Assert.assertEquals("echoString", invocation.getMethodName());
        Assert.assertArrayEquals(new Class[]{String.class}, invocation.getParameterTypes());
        Assert.assertArrayEquals(new Object[]{args.getArg()}, invocation.getArguments());
    }

    /**
     * Builds a request for {@code echoString("Hello, World!")} on {@code Demo.Iface} with id 1.
     */
    private Request createRequest() {
        RpcInvocation invocation = new RpcInvocation();
        invocation.setMethodName("echoString");
        invocation.setArguments(new Object[]{"Hello, World!"});
        invocation.setParameterTypes(new Class<?>[]{String.class});
        invocation.setAttachment(Constants.INTERFACE_KEY, Demo.Iface.class.getName());

        Request request = new Request(1L);
        request.setData(invocation);
        return request;
    }

    /**
     * Copies every readable byte of {@code buffer} into a new array, advancing its reader index.
     */
    private static byte[] readableBytesOf(ChannelBuffer buffer) {
        byte[] bytes = new byte[buffer.readableBytes()];
        buffer.readBytes(bytes);
        return bytes;
    }

    /**
     * Writes the Dubbo thrift header, using placeholder values for the message-length and
     * header-length fields so they can be back-patched by
     * {@link #patchLengths(RandomAccessByteArrayOutputStream, TBinaryProtocol, int)}.
     *
     * @param bos the stream {@code protocol} writes into; expected to be empty
     * @param protocol protocol writing into {@code bos}
     * @param serviceName service name stored in the header
     * @param requestId Dubbo request id stored in the header
     * @return the header length, i.e. the number of bytes in {@code bos} after the header
     */
    private static int writeHeader(RandomAccessByteArrayOutputStream bos, TBinaryProtocol protocol,
                                   String serviceName, long requestId) throws Exception {
        protocol.writeI16(ThriftCodec.MAGIC);
        protocol.writeI32(Integer.MAX_VALUE); // message length placeholder
        protocol.writeI16(Short.MAX_VALUE);   // header length placeholder
        protocol.writeByte(ThriftCodec.VERSION);
        protocol.writeString(serviceName);
        protocol.writeI64(requestId);
        protocol.getTransport().flush();
        return bos.size();
    }

    /**
     * Flushes {@code protocol}, then overwrites the message-length and header-length
     * placeholders with the real values and restores the write index to the end of the data.
     *
     * @param bos the stream holding the complete message (header + thrift message)
     * @param protocol protocol writing into {@code bos}
     * @param headerLength value returned by {@code writeHeader}
     */
    private static void patchLengths(RandomAccessByteArrayOutputStream bos, TBinaryProtocol protocol,
                                     int headerLength) throws Exception {
        protocol.getTransport().flush();
        int messageLength = bos.size();
        try {
            bos.setWriteIndex(ThriftCodec.MESSAGE_LENGTH_INDEX);
            protocol.writeI32(messageLength);
            bos.setWriteIndex(ThriftCodec.MESSAGE_HEADER_LENGTH_INDEX);
            protocol.writeI16((short) (0xffff & headerLength));
        } finally {
            bos.setWriteIndex(messageLength);
        }
    }

    /**
     * Reads the Dubbo thrift header from {@code input}, asserts each field, and leaves
     * {@code input} positioned at the first byte of the embedded thrift message.
     *
     * @param input stream positioned just after the 4-byte frame length
     * @param protocol protocol reading from {@code input}
     * @param framedLength total encoded length, including the 4-byte frame prefix
     * @param requestId expected Dubbo request id
     */
    private static void readAndVerifyHeader(ByteArrayInputStream input, TBinaryProtocol protocol,
                                            int framedLength, long requestId) throws Exception {
        input.mark(0);

        Assert.assertEquals(ThriftCodec.MAGIC, protocol.readI16());
        // the message length excludes the 4-byte frame prefix
        Assert.assertEquals(framedLength, protocol.readI32() + 4);
        int headerLength = protocol.readI16();
        Assert.assertEquals(ThriftCodec.VERSION, protocol.readByte());
        Assert.assertEquals(Demo.Iface.class.getName(), protocol.readString());
        Assert.assertEquals(requestId, protocol.readI64());

        // the header length is measured from the message start, so rewind and skip it whole;
        // this also checks the header-length field against what was actually read
        input.reset();
        input.skip(headerLength);
    }

}