/*
 * The MIT License (MIT)
 * Copyright (c) 2018 Microsoft Corporation
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package com.microsoft.azure.cosmosdb.rx.internal;

import java.time.Duration;
import java.util.concurrent.TimeUnit;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import com.microsoft.azure.cosmosdb.internal.Quadruple;
import com.microsoft.azure.cosmosdb.internal.directconnectivity.GoneException;
import com.microsoft.azure.cosmosdb.internal.directconnectivity.StoreResponse;
import com.microsoft.azure.cosmosdb.internal.directconnectivity.StoreResponseBuilder;
import com.microsoft.azure.cosmosdb.internal.directconnectivity.StoreResponseValidator;
import com.microsoft.azure.cosmosdb.rx.internal.IRetryPolicy.ShouldRetryResult;

import rx.Single;
import rx.functions.Func1;
import rx.observers.TestSubscriber;

public class RetryUtilsTest {
    IRetryPolicy retryPolicy;
    Func1<Quadruple<Boolean, Boolean, Duration, Integer>, Single<StoreResponse>> callbackMethod;
    Func1<Quadruple<Boolean, Boolean, Duration, Integer>, Single<StoreResponse>> inBackoffAlternateCallbackMethod;
    private static final Duration minBackoffForInBackoffCallback = Duration.ofMillis(10);
    private static final int TIMEOUT = 30000;
    private static final Duration BACK_OFF_DURATION = Duration.ofMillis(20);
    private StoreResponse storeResponse;

    @BeforeClass(groups = { "unit" })
    public void beforeClass() throws Exception {
        retryPolicy = Mockito.mock(IRetryPolicy.class);
        callbackMethod = Mockito.mock(Func1.class);
        inBackoffAlternateCallbackMethod = Mockito.mock(Func1.class);
        storeResponse = getStoreResponse();
    }

    /**
     * This method will make sure we are throwing original exception in case of
     * ShouldRetryResult.noRetry() instead of Single.error(null).
     */
    @Test(groups = { "unit" }, timeOut = TIMEOUT)
    public void toRetryWithAlternateFuncWithNoRetry() {
        Func1<Throwable, Single<StoreResponse>> onErrorFunc = RetryUtils.toRetryWithAlternateFunc(callbackMethod,
                retryPolicy, inBackoffAlternateCallbackMethod, minBackoffForInBackoffCallback);
        Mockito.when(retryPolicy.shouldRetry(Matchers.any())).thenReturn(Single.just(ShouldRetryResult.noRetry()));
        Single<StoreResponse> response = onErrorFunc.call(new GoneException());
        validateFailure(response, TIMEOUT, GoneException.class);
    }

    /**
     * This method will test retries on callbackMethod, eventually returning success
     * response after some failures and making sure it failed for at least specific
     * number before passing.
     */
    @Test(groups = { "unit" }, timeOut = TIMEOUT)
    public void toRetryWithAlternateFuncTestingMethodOne() {
        Func1<Throwable, Single<StoreResponse>> onErrorFunc = RetryUtils.toRetryWithAlternateFunc(callbackMethod,
                retryPolicy, null, minBackoffForInBackoffCallback);

        toggleMockFuncBtwFailureSuccess(callbackMethod);
        Mockito.when(retryPolicy.shouldRetry(Matchers.any()))
                .thenReturn(Single.just(ShouldRetryResult.retryAfter(BACK_OFF_DURATION)));
        Single<StoreResponse> response = onErrorFunc.call(new GoneException());
        StoreResponseValidator validator = StoreResponseValidator.create().withStatus(storeResponse.getStatus())
                .withContent(storeResponse.getResponseBody()).build();
        validateSuccess(response, validator, TIMEOUT);
        Mockito.verify(callbackMethod, Mockito.times(4)).call(Matchers.any());
    }

    /**
     * This method will test retries on inBackoffAlternateCallbackMethod, eventually
     * returning success response after some failures and making sure it failed for
     * at least specific number before passing.
     */
    @Test(groups = { "unit" }, timeOut = TIMEOUT)
    public void toRetryWithAlternateFuncTestingMethodTwo() {
        Func1<Throwable, Single<StoreResponse>> onErrorFunc = RetryUtils.toRetryWithAlternateFunc(callbackMethod,
                retryPolicy, inBackoffAlternateCallbackMethod, minBackoffForInBackoffCallback);
        Mockito.when(callbackMethod.call(Matchers.any())).thenReturn(Single.error(new GoneException()));
        toggleMockFuncBtwFailureSuccess(inBackoffAlternateCallbackMethod);
        Mockito.when(retryPolicy.shouldRetry(Matchers.any()))
                .thenReturn(Single.just(ShouldRetryResult.retryAfter(BACK_OFF_DURATION)));
        Single<StoreResponse> response = onErrorFunc.call(new GoneException());
        StoreResponseValidator validator = StoreResponseValidator.create().withStatus(storeResponse.getStatus())
                .withContent(storeResponse.getResponseBody()).build();
        validateSuccess(response, validator, TIMEOUT);
        Mockito.verify(inBackoffAlternateCallbackMethod, Mockito.times(4)).call(Matchers.any());
    }

    private void validateFailure(Single<StoreResponse> single, long timeout, Class<? extends Throwable> class1) {

        TestSubscriber<StoreResponse> testSubscriber = new TestSubscriber<>();
        single.toObservable().subscribe(testSubscriber);
        testSubscriber.awaitTerminalEvent(timeout, TimeUnit.MILLISECONDS);
        testSubscriber.assertNotCompleted();
        testSubscriber.assertTerminalEvent();
        assertThat(testSubscriber.getOnErrorEvents()).hasSize(1);
        if (!(testSubscriber.getOnErrorEvents().get(0).getClass().equals(class1))) {
            fail("Not expecting " + testSubscriber.getOnErrorEvents().get(0));
        }
    }

    private void validateSuccess(Single<StoreResponse> single, StoreResponseValidator validator, long timeout) {

        TestSubscriber<StoreResponse> testSubscriber = new TestSubscriber<>();
        single.toObservable().subscribe(testSubscriber);
        testSubscriber.awaitTerminalEvent(timeout, TimeUnit.MILLISECONDS);
        assertThat(testSubscriber.getOnNextEvents()).hasSize(1);
        validator.validate(testSubscriber.getOnNextEvents().get(0));
    }

    private void toggleMockFuncBtwFailureSuccess(
            Func1<Quadruple<Boolean, Boolean, Duration, Integer>, Single<StoreResponse>> method) {
        Mockito.when(method.call(Matchers.any())).thenAnswer(new Answer<Single<StoreResponse>>() {

            private int count = 0;

            @Override
            public Single<StoreResponse> answer(InvocationOnMock invocation) throws Throwable {
                if (count++ < 3) {
                    return Single.error(new GoneException());
                }
                return Single.just(storeResponse);
            }
        });
    }

    private StoreResponse getStoreResponse() {
        StoreResponseBuilder storeResponseBuilder = new StoreResponseBuilder().withContent("Test content")
                .withStatus(200);
        return storeResponseBuilder.build();
    }
}