/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.airlift.drift.server;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Injector;
import io.airlift.bootstrap.Bootstrap;
import io.airlift.bootstrap.LifeCycleManager;
import io.airlift.drift.TApplicationException;
import io.airlift.drift.TException;
import io.airlift.drift.annotations.ThriftException;
import io.airlift.drift.annotations.ThriftMethod;
import io.airlift.drift.annotations.ThriftService;
import io.airlift.drift.annotations.ThriftStruct;
import io.airlift.drift.codec.ThriftCodecManager;
import io.airlift.drift.server.TestingServerTransport.State;
import io.airlift.drift.server.stats.MethodInvocationStatsFactory;
import io.airlift.drift.transport.server.ServerTransportFactory;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;

import static com.google.common.util.concurrent.Futures.getDone;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static io.airlift.drift.server.TestingInvocationTarget.combineTestingInvocationTarget;
import static io.airlift.drift.server.guice.DriftServerBinder.driftServerBinder;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertSame;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public class TestDriftServer
{
    /**
     * Drives a directly constructed {@link DriftServer} through its lifecycle.
     * The testing transport must report {@code NOT_STARTED} after construction,
     * {@code RUNNING} after {@code start()} and {@code SHUTDOWN} after
     * {@code shutdown()}. While the server is running, the shared
     * {@code testServer} scenarios are executed with the service itself acting
     * as the invocation target.
     */
    @Test
    public void testInvoker()
            throws Exception
    {
        ResultsSupplier results = new ResultsSupplier();
        TestService service = new TestService(results);
        TestingServerTransportFactory transportFactory = new TestingServerTransportFactory();
        TestingMethodInvocationStatsFactory stats = new TestingMethodInvocationStatsFactory();

        // one service and an empty second set (the filter tests pass their filters there)
        DriftService driftService = new DriftService(service, Optional.empty(), true);
        DriftServer server = new DriftServer(
                transportFactory,
                new ThriftCodecManager(),
                stats,
                ImmutableSet.of(driftService),
                ImmutableSet.of());

        // the transport exists as soon as the server is constructed, but is not started yet
        TestingServerTransport transport = transportFactory.getServerTransport();
        assertNotNull(transport);
        assertEquals(transport.getState(), State.NOT_STARTED);

        server.start();
        assertEquals(transport.getState(), State.RUNNING);

        testServer(results, service, stats, transport);

        server.shutdown();
        assertEquals(transport.getState(), State.SHUTDOWN);
    }

    /**
     * Same lifecycle check as the plain invoker test, but with two filters
     * installed on the {@link DriftServer}. The service is wired with a supplier
     * that always yields a failed future ("Should not be called"), and the
     * combined filters are used as the invocation target for the shared
     * {@code testServer} scenarios.
     */
    @Test
    public void testFilter()
            throws Exception
    {
        ResultsSupplier results = new ResultsSupplier();
        PassThroughFilter passThrough = new PassThroughFilter();
        ShortCircuitFilter shortCircuit = new ShortCircuitFilter(results);
        // the service itself is expected never to see an invocation; if it does, its result is a failure
        TestService service = new TestService(() -> Futures.immediateFailedFuture(new Exception("Should not be called")));

        TestingServerTransportFactory transportFactory = new TestingServerTransportFactory();
        TestingMethodInvocationStatsFactory stats = new TestingMethodInvocationStatsFactory();
        DriftService driftService = new DriftService(service, Optional.empty(), true);
        DriftServer server = new DriftServer(
                transportFactory,
                new ThriftCodecManager(),
                stats,
                ImmutableSet.of(driftService),
                ImmutableSet.of(passThrough, shortCircuit));

        TestingServerTransport transport = transportFactory.getServerTransport();
        assertNotNull(transport);
        assertEquals(transport.getState(), State.NOT_STARTED);

        server.start();
        assertEquals(transport.getState(), State.RUNNING);

        // invocations are asserted against both filters rather than against the service
        TestingInvocationTarget filterTarget = combineTestingInvocationTarget(passThrough, shortCircuit);
        testServer(results, filterTarget, stats, transport);

        server.shutdown();
        assertEquals(transport.getState(), State.SHUTDOWN);
    }

    /**
     * Builds the server through Guice ({@link Bootstrap} plus
     * {@code driftServerBinder}) instead of constructing it by hand. After the
     * injector is initialized the transport must already be {@code RUNNING};
     * the shared {@code testServer} scenarios are then executed, and once the
     * {@link LifeCycleManager} has been stopped the transport must be
     * {@code SHUTDOWN}.
     */
    @Test
    public void testGuiceServer()
    {
        ResultsSupplier results = new ResultsSupplier();
        TestService service = new TestService(results);

        TestingServerTransportFactory transportFactory = new TestingServerTransportFactory();
        TestingMethodInvocationStatsFactory stats = new TestingMethodInvocationStatsFactory();

        Bootstrap bootstrap = new Bootstrap(
                binder -> binder.bind(TestService.class).toInstance(service),
                binder -> driftServerBinder(binder).bindService(TestService.class),
                binder -> binder.bind(ServerTransportFactory.class).toInstance(transportFactory),
                binder -> newOptionalBinder(binder, MethodInvocationStatsFactory.class)
                        .setBinding()
                        .toInstance(stats));

        LifeCycleManager lifeCycle = null;
        try {
            Injector injector = bootstrap
                    .strictConfig()
                    .doNotInitializeLogging()
                    .initialize();
            lifeCycle = injector.getInstance(LifeCycleManager.class);

            TestingServerTransport transport = transportFactory.getServerTransport();
            assertEquals(transport.getState(), State.RUNNING);

            testServer(results, service, stats, transport);
        }
        catch (Exception e) {
            // the test method declares no checked exceptions, so rethrow unchecked
            throw new RuntimeException(e);
        }
        finally {
            if (lifeCycle != null) {
                try {
                    lifeCycle.stop();
                }
                catch (Exception ignored) {
                    // best effort: the SHUTDOWN assertion below still checks the transport state
                }
            }
        }

        assertEquals(transportFactory.getServerTransport().getState(), State.SHUTDOWN);
    }

    /**
     * Guice-built variant of the filter test. Two filters are registered through
     * {@code driftServerBinder(...).bindFilter(...)}, the service is wired with a
     * supplier that always yields a failed future ("Should not be called"), and
     * the combined filters are used as the invocation target for the shared
     * {@code testServer} scenarios. The transport must be {@code RUNNING} after
     * initialization and {@code SHUTDOWN} once the {@link LifeCycleManager} has
     * been stopped.
     */
    @Test
    public void testGuiceServerFilter()
    {
        ResultsSupplier results = new ResultsSupplier();
        PassThroughFilter passThrough = new PassThroughFilter();
        ShortCircuitFilter shortCircuit = new ShortCircuitFilter(results);
        // the service itself is expected never to see an invocation; if it does, its result is a failure
        TestService service = new TestService(() -> Futures.immediateFailedFuture(new Exception("Should not be called")));

        TestingServerTransportFactory transportFactory = new TestingServerTransportFactory();
        TestingMethodInvocationStatsFactory stats = new TestingMethodInvocationStatsFactory();

        Bootstrap bootstrap = new Bootstrap(
                binder -> binder.bind(TestService.class).toInstance(service),
                binder -> driftServerBinder(binder).bindService(TestService.class),
                binder -> driftServerBinder(binder).bindFilter(passThrough),
                binder -> driftServerBinder(binder).bindFilter(shortCircuit),
                binder -> binder.bind(ServerTransportFactory.class).toInstance(transportFactory),
                binder -> newOptionalBinder(binder, MethodInvocationStatsFactory.class)
                        .setBinding()
                        .toInstance(stats));

        LifeCycleManager lifeCycle = null;
        try {
            Injector injector = bootstrap
                    .strictConfig()
                    .doNotInitializeLogging()
                    .initialize();
            lifeCycle = injector.getInstance(LifeCycleManager.class);

            TestingServerTransport transport = transportFactory.getServerTransport();
            assertEquals(transport.getState(), State.RUNNING);

            // invocations are asserted against both filters rather than against the service
            TestingInvocationTarget filterTarget = combineTestingInvocationTarget(passThrough, shortCircuit);
            testServer(results, filterTarget, stats, transport);
        }
        catch (Exception e) {
            // the test method declares no checked exceptions, so rethrow unchecked
            throw new RuntimeException(e);
        }
        finally {
            if (lifeCycle != null) {
                try {
                    lifeCycle.stop();
                }
                catch (Exception ignored) {
                    // best effort: the SHUTDOWN assertion below still checks the transport state
                }
            }
        }

        assertEquals(transportFactory.getServerTransport().getState(), State.SHUTDOWN);
    }

    /**
     * Runs the invocation scenarios shared by all tests in this class: one
     * successful round of calls, followed by a failing round for each of several
     * throwable types.
     *
     * @param resultsSupplier source of the result (or failure) each invocation produces
     * @param invocationTarget object that recorded the invocation and is asserted against
     * @param statsFactory factory holding the per-method stats that are checked
     * @param serverTransport transport through which the calls are made
     * @throws ExecutionException if a call that is expected to succeed fails
     */
    private static void testServer(
            ResultsSupplier resultsSupplier,
            TestingInvocationTarget invocationTarget,
            TestingMethodInvocationStatsFactory statsFactory,
            TestingServerTransport serverTransport)
            throws ExecutionException
    {
        // test normal invocation
        assertNormalInvocation(resultsSupplier, serverTransport, invocationTarget, statsFactory, Optional.empty());

        Throwable[] failures = {
                // test method throws TException
                new TestServiceException(),
                new TException(),
                new TApplicationException(),
                new RuntimeException(),
                new Error(),
                // custom exception subclasses
                new TestServiceException() {},
        };
        for (Throwable failure : failures) {
            assertExceptionInvocation(resultsSupplier, serverTransport, invocationTarget, statsFactory, Optional.empty(), failure);
        }
    }

    /**
     * Invokes {@code test} and then {@code testAsync} with a successful result
     * configured on the supplier. For each call it checks that the returned
     * future is already done and carries the expected value, that the invocation
     * target recorded the call, and that the method's stat reports success.
     *
     * @param qualifier qualifier used to look up the stats for {@code serverService}
     * @throws ExecutionException if one of the returned futures failed
     */
    private static void assertNormalInvocation(
            ResultsSupplier resultsSupplier,
            TestingServerTransport serverTransport,
            TestingInvocationTarget invocationTarget,
            TestingMethodInvocationStatsFactory statsFactory,
            Optional<String> qualifier)
            throws ExecutionException
    {
        // --- synchronous method ---
        TestingMethodInvocationStat syncStat = statsFactory.getStat("serverService", qualifier, "test");
        syncStat.clear();
        int syncId = ThreadLocalRandom.current().nextInt();
        String syncExpected = "result " + syncId;
        resultsSupplier.setSuccessResult(syncExpected);
        ListenableFuture<Object> syncFuture = serverTransport.invoke(
                "test",
                ImmutableMap.of(),
                ImmutableMap.of((short) 1, syncId, (short) 2, "normal"));
        assertTrue(syncFuture.isDone());
        assertEquals(getDone(syncFuture), syncExpected);
        invocationTarget.assertInvocation("test", syncId, "normal");
        syncStat.assertSuccess();

        // --- asynchronous method ---
        TestingMethodInvocationStat asyncStat = statsFactory.getStat("serverService", qualifier, "testAsync");
        asyncStat.clear();
        int asyncId = ThreadLocalRandom.current().nextInt();
        // the expected value is derived from the previous result string, not from asyncId
        String asyncExpected = "async " + syncExpected;
        resultsSupplier.setSuccessResult(asyncExpected);
        ListenableFuture<Object> asyncFuture = serverTransport.invoke(
                "testAsync",
                ImmutableMap.of(),
                ImmutableMap.of((short) 1, asyncId, (short) 2, "async"));
        assertTrue(asyncFuture.isDone());
        assertEquals(getDone(asyncFuture), asyncExpected);
        invocationTarget.assertInvocation("testAsync", asyncId, "async");
        asyncStat.assertSuccess();
    }

    /**
     * Invokes {@code test} and then {@code testAsync} with {@code testException}
     * configured as the failure on the supplier. For each call it checks that the
     * returned future is already done and failed with exactly that throwable,
     * that the invocation target recorded the call, and that the method's stat
     * reports a failure.
     *
     * @param qualifier qualifier used to look up the stats for {@code serverService}
     * @param testException the throwable each invocation is expected to fail with
     */
    private static void assertExceptionInvocation(
            ResultsSupplier resultsSupplier,
            TestingServerTransport serverTransport,
            TestingInvocationTarget invocationTarget,
            TestingMethodInvocationStatsFactory statsFactory,
            Optional<String> qualifier,
            Throwable testException)
    {
        String name = "exception-" + testException.getClass().getName();

        TestingMethodInvocationStat testStat = statsFactory.getStat("serverService", qualifier, "test");
        testStat.clear();
        int invocationId = ThreadLocalRandom.current().nextInt();
        resultsSupplier.setFailedResult(testException);
        ListenableFuture<Object> result = serverTransport.invoke("test", ImmutableMap.of(), ImmutableMap.of((short) 1, invocationId, (short) 2, name));
        assertFailedWith(result, testException);
        invocationTarget.assertInvocation("test", invocationId, name);
        testStat.assertFailure();

        name = "async " + name;
        TestingMethodInvocationStat testAsyncStat = statsFactory.getStat("serverService", qualifier, "testAsync");
        testAsyncStat.clear();
        invocationId = ThreadLocalRandom.current().nextInt();
        resultsSupplier.setFailedResult(testException);
        ListenableFuture<Object> asyncResult = serverTransport.invoke("testAsync", ImmutableMap.of(), ImmutableMap.of((short) 1, invocationId, (short) 2, name));
        // verify the async future itself; checking the earlier sync future here would leave testAsync's failure untested
        assertFailedWith(asyncResult, testException);
        invocationTarget.assertInvocation("testAsync", invocationId, name);
        testAsyncStat.assertFailure();
    }

    /**
     * Asserts that {@code future} is already complete and failed with exactly
     * {@code expectedCause} (compared by identity).
     */
    private static void assertFailedWith(ListenableFuture<Object> future, Throwable expectedCause)
    {
        assertTrue(future.isDone());
        try {
            getDone(future);
            fail("expected exception");
        }
        catch (ExecutionException e) {
            assertSame(e.getCause(), expectedCause);
        }
    }

    /**
     * Thrift service used by the tests. Each method records its own name and
     * arguments so they can be checked later through {@link #assertInvocation},
     * and produces whatever the injected supplier currently yields.
     */
    @ThriftService("serverService")
    public static class TestService
            implements TestingInvocationTarget
    {
        private final Supplier<ListenableFuture<Object>> resultsSupplier;

        // details of the most recent invocation
        private String methodName;
        private int id;
        private String name;

        public TestService(Supplier<ListenableFuture<Object>> resultsSupplier)
        {
            this.resultsSupplier = resultsSupplier;
        }

        /**
         * Synchronous method: returns the supplier's value, or rethrows the
         * supplier's failure. A {@link TestServiceException}, {@link TException},
         * {@link RuntimeException} or {@link Error} is rethrown as is; any other
         * failure is wrapped in a {@link RuntimeException}.
         */
        @ThriftMethod
        public String test(int id, String name)
                throws TestServiceException, TException
        {
            recordInvocation("test", id, name);

            ListenableFuture<Object> pending = resultsSupplier.get();
            try {
                Object value = getDone(pending);
                return (String) value;
            }
            catch (ExecutionException e) {
                Throwable cause = e.getCause();
                Throwables.propagateIfPossible(cause, TestServiceException.class);
                Throwables.propagateIfPossible(cause, TException.class);
                throw new RuntimeException(cause);
            }
        }

        /**
         * Asynchronous method: returns the supplier's future with its value
         * converted through {@link String#valueOf(Object)}.
         */
        @ThriftMethod(exception = @ThriftException(id = 0, type = TestServiceException.class))
        public ListenableFuture<String> testAsync(int id, String name)
        {
            recordInvocation("testAsync", id, name);

            ListenableFuture<Object> pending = resultsSupplier.get();
            return Futures.transform(pending, String::valueOf, directExecutor());
        }

        @Override
        public void assertInvocation(String expectedMethodName, int expectedId, String expectedName)
        {
            assertEquals(methodName, expectedMethodName);
            assertEquals(id, expectedId);
            assertEquals(name, expectedName);
        }

        // remembers which method was called and with which arguments
        private void recordInvocation(String invokedMethod, int invokedId, String invokedName)
        {
            this.methodName = invokedMethod;
            this.id = invokedId;
            this.name = invokedName;
        }
    }

    /**
     * Checked exception declared by {@link TestService}: {@code test} lists it in
     * its {@code throws} clause and {@code testAsync} declares it through
     * {@code @ThriftException}. It carries no fields.
     * <p>
     * The class is intentionally not final: {@code testServer} instantiates an
     * anonymous subclass of it.
     */
    @ThriftStruct("testService")
    public static class TestServiceException
            extends Exception
    {
    }
}