package com.auth0.android.request.internal;

import com.squareup.okhttp.ConnectionSpec;
import com.squareup.okhttp.Interceptor;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.Protocol;
import com.squareup.okhttp.TlsVersion;
import com.squareup.okhttp.logging.HttpLoggingInterceptor;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;

import java.util.List;

import javax.net.ssl.SSLSocketFactory;

import static junit.framework.Assert.assertTrue;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@RunWith(RobolectricTestRunner.class)
@Config(sdk = 21)
public class OkHttpClientFactoryTest {

    private OkHttpClientFactory factory;
    @Mock
    private OkHttpClient mockClient;

    @Before
    public void setUp() {
        MockitoAnnotations.initMocks(this);
        factory = new OkHttpClientFactory();
    }

    @Test
    // Verify that there's no error when creating a new OkHttpClient instance
    public void shouldCreateNewClient() {
        factory.createClient(false, false, 0, 0, 0);
    }

    @Test
    public void shouldNotUseHttp2Protocol() {
        OkHttpClient client = factory.createClient(false, false, 0, 0, 0);
        //Doesn't use default protocols
        assertThat(client.getProtocols(), is(notNullValue()));
        assertThat(client.getProtocols().contains(Protocol.HTTP_1_1), is(true));
        assertThat(client.getProtocols().contains(Protocol.SPDY_3), is(true));
        assertThat(client.getProtocols().contains(Protocol.HTTP_2), is(false));
    }

    @Test
    public void shouldUseDefaultTimeoutWhenTimeoutZero() {
        OkHttpClient client = factory.createClient(false, false, 0, 0, 0);
        assertThat(client.getConnectTimeout(), is(10000));
        assertThat(client.getReadTimeout(), is(10000));
        assertThat(client.getWriteTimeout(), is(10000));
    }

    @Test
    public void shouldUsePassedInTimeout() {
        OkHttpClient client = factory.createClient(false, false, 5, 15, 20);
        assertThat(client.getConnectTimeout(), is(5000));
        assertThat(client.getReadTimeout(), is(15000));
        assertThat(client.getWriteTimeout(), is(20000));
    }

    @Test
    @Config(sdk = 21)
    public void shouldEnableLoggingTLS12Enforced() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, true, true, 0, 0, 0);
        verifyLoggingEnabled(client, list);
        verifyTLS12Enforced(client);
    }

    @Test
    @Config(sdk = 21)
    public void shouldEnableLoggingTLS12NotEnforced() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, true, false, 0, 0, 0);
        verifyLoggingEnabled(client, list);
        verifyTLS12NotEnforced(client);
    }

    @Test
    @Config(sdk = 21)
    public void shouldDisableLoggingTLS12Enforced() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, false, true, 0, 0, 0);
        verifyLoggingDisabled(client, list);
        verifyTLS12Enforced(client);
    }

    @Test
    @Config(sdk = 21)
    public void shouldDisableLoggingTLS12NotEnforced() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, false, false, 0, 0, 0);
        verifyLoggingDisabled(client, list);
        verifyTLS12NotEnforced(client);
    }

    @Test
    @Config(sdk = 22)
    public void shouldEnableLoggingTLS12Enforced_postLollipopTLS12NoEffect() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, true, true, 0, 0, 0);
        verifyLoggingEnabled(client, list);
        verifyTLS12NotEnforced(client);
    }

    @Test
    @Config(sdk = 22)
    public void shouldEnableLoggingTLS12NotEnforced_posLollipop() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, true, false, 0, 0, 0);
        verifyLoggingEnabled(client, list);
        verifyTLS12NotEnforced(client);
    }

    @Test
    @Config(sdk = 22)
    public void shouldDisableLoggingTLS12Enforced_postLollipopTLS12NoEffect() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, false, true, 0, 0, 0);
        verifyLoggingDisabled(client, list);
        verifyTLS12NotEnforced(client);
    }

    @Test
    @Config(sdk = 22)
    public void shouldDisableLoggingTLS12NotEnforced_postLollipop() {
        List list = generateInterceptorsMockList(mockClient);
        OkHttpClient client = factory.modifyClient(mockClient, false, false, 0, 0, 0);
        verifyLoggingDisabled(client, list);
        verifyTLS12NotEnforced(client);
    }

    private static List generateInterceptorsMockList(OkHttpClient client) {
        List list = mock(List.class);
        when(client.interceptors()).thenReturn(list);
        return list;
    }

    private static void verifyLoggingEnabled(OkHttpClient client, List list) {
        verify(client).interceptors();

        ArgumentCaptor<Interceptor> interceptorCaptor = ArgumentCaptor.forClass(Interceptor.class);
        verify(list).add(interceptorCaptor.capture());

        assertThat(interceptorCaptor.getValue(), is(notNullValue()));
        assertThat(interceptorCaptor.getValue(), is(instanceOf(HttpLoggingInterceptor.class)));
        assertThat(((HttpLoggingInterceptor) interceptorCaptor.getValue()).getLevel(), is(HttpLoggingInterceptor.Level.BODY));
    }

    private static void verifyLoggingDisabled(OkHttpClient client, List list) {
        verify(client, never()).interceptors();
        verify(list, never()).add(any(Interceptor.class));
    }

    private static void verifyTLS12NotEnforced(OkHttpClient client) {
        verify(client, never()).setSslSocketFactory((SSLSocketFactory) any());
    }

    private static void verifyTLS12Enforced(OkHttpClient client) {

        ArgumentCaptor<SSLSocketFactory> factoryCaptor = ArgumentCaptor.forClass(SSLSocketFactory.class);
        verify(client).setSslSocketFactory(factoryCaptor.capture());
        assertTrue(factoryCaptor.getValue() instanceof TLS12SocketFactory);

        ArgumentCaptor<List> specCaptor = ArgumentCaptor.forClass(List.class);
        verify(client).setConnectionSpecs(specCaptor.capture());
        boolean hasTls12 = false;
        for (Object item : specCaptor.getValue()) {
            assertTrue(item instanceof ConnectionSpec);
            ConnectionSpec spec = (ConnectionSpec) item;
            if (!spec.isTls()) {
                continue;
            }
            List<TlsVersion> versions = spec.tlsVersions();
            for (TlsVersion version : versions) {
                if ("TLSv1.2".equals(version.javaName())) {
                    hasTls12 = true;
                    break;
                }
            }
        }
        assertTrue(hasTls12);
    }
}