/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.crypto.key.kms;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.net.URI;
import java.security.NoSuchAlgorithmException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.crypto.key.KeyProvider;
import org.apache.hadoop.crypto.key.KeyProvider.Options;
import org.junit.Test;
import org.mockito.Mockito;

import com.google.common.collect.Sets;

public class TestLoadBalancingKMSClientProvider {

  @Test
  public void testCreation() throws Exception {
    Configuration conf = new Configuration();
    KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI(
        "kms://http@host1/kms/foo"), conf);
    assertTrue(kp instanceof KMSClientProvider);
    assertEquals("http://host1/kms/foo/v1/",
        ((KMSClientProvider) kp).getKMSUrl());

    kp = new KMSClientProvider.Factory().createProvider(new URI(
        "kms://http@host1;host2;host3/kms/foo"), conf);
    assertTrue(kp instanceof LoadBalancingKMSClientProvider);
    KMSClientProvider[] providers =
        ((LoadBalancingKMSClientProvider) kp).getProviders();
    assertEquals(3, providers.length);
    assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/",
        "http://host2/kms/foo/v1/",
        "http://host3/kms/foo/v1/"),
        Sets.newHashSet(providers[0].getKMSUrl(),
            providers[1].getKMSUrl(),
            providers[2].getKMSUrl()));

    kp = new KMSClientProvider.Factory().createProvider(new URI(
        "kms://http@host1;host2;host3:16000/kms/foo"), conf);
    assertTrue(kp instanceof LoadBalancingKMSClientProvider);
    providers =
        ((LoadBalancingKMSClientProvider) kp).getProviders();
    assertEquals(3, providers.length);
    assertEquals(Sets.newHashSet("http://host1:16000/kms/foo/v1/",
        "http://host2:16000/kms/foo/v1/",
        "http://host3:16000/kms/foo/v1/"),
        Sets.newHashSet(providers[0].getKMSUrl(),
            providers[1].getKMSUrl(),
            providers[2].getKMSUrl()));
  }

  @Test
  public void testLoadBalancing() throws Exception {
    Configuration conf = new Configuration();
    KMSClientProvider p1 = mock(KMSClientProvider.class);
    when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenReturn(
            new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0]));
    KMSClientProvider p2 = mock(KMSClientProvider.class);
    when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenReturn(
            new KMSClientProvider.KMSKeyVersion("p2", "v2", new byte[0]));
    KMSClientProvider p3 = mock(KMSClientProvider.class);
    when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenReturn(
            new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0]));
    KeyProvider kp = new LoadBalancingKMSClientProvider(
        new KMSClientProvider[] { p1, p2, p3 }, 0, conf);
    assertEquals("p1", kp.createKey("test1", new Options(conf)).getName());
    assertEquals("p2", kp.createKey("test2", new Options(conf)).getName());
    assertEquals("p3", kp.createKey("test3", new Options(conf)).getName());
    assertEquals("p1", kp.createKey("test4", new Options(conf)).getName());
  }

  @Test
  public void testLoadBalancingWithFailure() throws Exception {
    Configuration conf = new Configuration();
    KMSClientProvider p1 = mock(KMSClientProvider.class);
    when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenReturn(
            new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0]));
    when(p1.getKMSUrl()).thenReturn("p1");
    // This should not be retried
    KMSClientProvider p2 = mock(KMSClientProvider.class);
    when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenThrow(new NoSuchAlgorithmException("p2"));
    when(p2.getKMSUrl()).thenReturn("p2");
    KMSClientProvider p3 = mock(KMSClientProvider.class);
    when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenReturn(
            new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0]));
    when(p3.getKMSUrl()).thenReturn("p3");
    // This should be retried
    KMSClientProvider p4 = mock(KMSClientProvider.class);
    when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenThrow(new IOException("p4"));
    when(p4.getKMSUrl()).thenReturn("p4");
    KeyProvider kp = new LoadBalancingKMSClientProvider(
        new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf);

    assertEquals("p1", kp.createKey("test4", new Options(conf)).getName());
    // Exceptions other than IOExceptions will not be retried
    try {
      kp.createKey("test1", new Options(conf)).getName();
      fail("Should fail since its not an IOException");
    } catch (Exception e) {
      assertTrue(e instanceof NoSuchAlgorithmException);
    }
    assertEquals("p3", kp.createKey("test2", new Options(conf)).getName());
    // IOException will trigger retry in next provider
    assertEquals("p1", kp.createKey("test3", new Options(conf)).getName());
  }

  @Test
  public void testLoadBalancingWithAllBadNodes() throws Exception {
    Configuration conf = new Configuration();
    KMSClientProvider p1 = mock(KMSClientProvider.class);
    when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenThrow(new IOException("p1"));
    KMSClientProvider p2 = mock(KMSClientProvider.class);
    when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenThrow(new IOException("p2"));
    KMSClientProvider p3 = mock(KMSClientProvider.class);
    when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenThrow(new IOException("p3"));
    KMSClientProvider p4 = mock(KMSClientProvider.class);
    when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
        .thenThrow(new IOException("p4"));
    when(p1.getKMSUrl()).thenReturn("p1");
    when(p2.getKMSUrl()).thenReturn("p2");
    when(p3.getKMSUrl()).thenReturn("p3");
    when(p4.getKMSUrl()).thenReturn("p4");
    KeyProvider kp = new LoadBalancingKMSClientProvider(
        new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf);
    try {
      kp.createKey("test3", new Options(conf)).getName();
      fail("Should fail since all providers threw an IOException");
    } catch (Exception e) {
      assertTrue(e instanceof IOException);
    }
  }
}