/*****************************************************************************
 * ------------------------------------------------------------------------- *
 * Licensed under the Apache License, Version 2.0 (the "License");           *
 * you may not use this file except in compliance with the License.          *
 * You may obtain a copy of the License at                                   *
 *                                                                           *
 * http://www.apache.org/licenses/LICENSE-2.0                                *
 *                                                                           *
 * Unless required by applicable law or agreed to in writing, software       *
 * distributed under the License is distributed on an "AS IS" BASIS,         *
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  *
 * See the License for the specific language governing permissions and       *
 * limitations under the License.                                            *
 *****************************************************************************/
package com.google.mu.util.concurrent;

import static com.google.common.truth.Truth.assertThat;
import static com.google.mu.util.concurrent.Parallelizer.forAll;
import static java.util.Arrays.asList;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeFalse;
import static org.junit.Assume.assumeTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Verifier;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

import com.google.common.truth.IterableSubject;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.mu.util.concurrent.Parallelizer.UncheckedExecutionException;

/**
 * Tests for {@link Parallelizer}, run once for every combination of {@code Mode} (which
 * {@code Parallelizer} entry point is invoked) and {@code Threading} (which kind of
 * {@link ExecutorService} backs it).
 *
 * <p>Every task is funneled through {@link #runTask}, which checks that the number of tasks
 * running at the same time never exceeds {@link #maxInFlight}. Failures observed on worker threads
 * cannot fail the test directly, so they are collected in {@link #thrown} and rethrown by the
 * {@link #verifyTaskAssertions} rule.
 */
@RunWith(Parameterized.class)
public class ParallelizerTest {

  /** The {@code Parallelizer} entry point exercised by this parameterized run. */
  private final Mode mode;

  /** The kind of executor service used by this parameterized run. */
  private final Threading threading;

  /** Created afresh before each test by {@link #initializeThreadPool}. */
  private ExecutorService threadPool;

  /** Concurrency limit handed to the {@code Parallelizer}; also read by worker threads. */
  private volatile int maxInFlight = 3;

  /** Timeout handed to the timed entry points; the uninterruptible modes ignore it. */
  private Duration timeout = Duration.ofMillis(100);

  /** Number of tasks currently executing inside {@link #runTask}. */
  private final AtomicInteger activeThreads = new AtomicInteger();

  /** Results recorded by {@link #translateToString}: each input mapped to its decimal string. */
  private final ConcurrentMap<Integer, String> translated = new ConcurrentHashMap<>();

  /** Failures captured on worker threads, to be rethrown by {@link #verifyTaskAssertions}. */
  private final ConcurrentLinkedQueue<Throwable> thrown = new ConcurrentLinkedQueue<>();

  /** Keys of the {@link #blockFor} calls that were woken up by an interruption. */
  private final ConcurrentLinkedQueue<Object> interrupted = new ConcurrentLinkedQueue<>();

  /**
   * Shuts the thread pool down, then surfaces everything recorded in {@link #thrown} and checks
   * that no task is still counted as active.
   */
  @Rule public final Verifier verifyTaskAssertions = new Verifier() {
    @Override protected void verify() throws Throwable {
      ParallelizerTest.this.shutdownThreadPool();
      // Report the first failure but keep the others attached as suppressed exceptions, so that
      // no worker-thread failure is silently dropped.
      Throwable first = null;
      for (Throwable e : thrown) {
        if (first == null) {
          first = e;
        } else if (e != first) {  // addSuppressed() rejects self-suppression
          first.addSuppressed(e);
        }
      }
      if (first != null) {
        throw first;
      }
      assertThat(activeThreads.get()).isEqualTo(0);
    }
  };

  public ParallelizerTest(Mode mode, Threading threading) {
    this.mode = mode;
    this.threading = threading;
  }

  @Before public void initializeThreadPool() {
    threadPool = threading.newExecutorService();
  }

  @Test public void testOneInFlight() throws Exception {
    maxInFlight = 1;
    List<Integer> numbers = asList(1, 2, 3, 4, 5, 6, 7, 8, 9);
    parallelize(numbers.stream(), this::translateToString);
    assertThat(translated).containsExactlyEntriesIn(mapToString(numbers));
  }

  @Test public void testFastTasks() throws Exception {
    List<Integer> numbers = asList(1, 2, 3, 4, 5, 6, 7, 8, 9);
    parallelize(numbers.stream(), this::translateToString);
    assertThat(translated).containsExactlyEntriesIn(mapToString(numbers));
  }

  @Test public void testSlowTasks() throws Exception {
    List<Integer> numbers = asList(1, 2, 3, 4, 5);
    parallelize(numbers.stream(), delayed(Duration.ofMillis(2), this::translateToString));
    assertThat(translated).containsExactlyEntriesIn(mapToString(numbers));
  }

  @Test public void testLargeMaxInFlight() throws Exception {
    maxInFlight = Integer.MAX_VALUE;
    List<Integer> numbers = asList(1, 2, 3);
    parallelize(numbers.stream(), this::translateToString);
    assertThat(translated).containsExactlyEntriesIn(mapToString(numbers));
  }

  @Test public void testTaskExceptionDismissesPendingTasks() {
    maxInFlight = 2;
    UncheckedExecutionException exception = assertThrows(
        UncheckedExecutionException.class,
        () -> parallelize(Stream.of(
            // With maxInFlight=2, at least one will print, even if a fail() task races it.
            () -> translateToString(1), () -> translateToString(1),
            () -> fail("foobar"), () -> fail("foobar"),  // both should fail
            () -> translateToString(5))));  // should be dismissed
    assertThat(exception.getCause().getMessage()).contains("foobar");
    assertThat(translated).containsEntry(1, "1");
    assertThat(translated).doesNotContainKey(5);
  }

  @Test public void testTaskExceptionCancelsInFlightTasks() throws InterruptedException {
    assumeFalse(threading == Threading.DIRECT);
    maxInFlight = 2;
    UncheckedExecutionException exception = assertThrows(
        UncheckedExecutionException.class,
        () -> parallelize(serialTasks(
            () -> translateToString(1),  // should print
            () -> blockFor(2),  // will be interrupted
            () -> fail("foobar"),  // kills the pipeline
            () -> translateToString(4))));  // should be dismissed
    assertThat(exception.getCause().getMessage()).contains("foobar");
    shutdownAndAssertInterruptedKeys().containsExactly(2);
    assertThat(translated).containsEntry(1, "1");
    assertThat(translated).doesNotContainKey(4);
  }

  @Test public void testSubmissionTimeoutCancelsInFlightTasks() throws InterruptedException {
    assumeFalse(threading == Threading.DIRECT);
    assumeTrue(mode == Mode.INTERRUPTIBLY);
    maxInFlight = 2;
    timeout = Duration.ofMillis(1);
    assertThrows(
        TimeoutException.class,
        () -> parallelize(serialTasks(
            () -> blockFor(1),  // will be interrupted
            () -> blockFor(2),  // will be interrupted
            () -> translateToString(3))));  // times out
    shutdownAndAssertInterruptedKeys().containsExactly(1, 2);
    assertThat(translated).doesNotContainKey(3);
  }

  @Test public void testAwaitTimeoutCancelsInFlightTasks() throws InterruptedException {
    assumeFalse(threading == Threading.DIRECT);
    assumeTrue(mode == Mode.INTERRUPTIBLY);
    maxInFlight = 2;
    timeout = Duration.ofMillis(1);
    assertThrows(
        TimeoutException.class,
        () -> parallelize(serialTasks(
            () -> blockFor(1),  // will be interrupted
            () -> blockFor(2))));  // might be interrupted
    shutdownAndAssertInterruptedKeys().contains(1);
  }

  @Test public void testUninterruptible() throws InterruptedException {
    assumeFalse(threading == Threading.DIRECT);
    assumeTrue(mode == Mode.UNINTERRUPTIBLY);
    maxInFlight = 2;
    List<Integer> numbers = asList(1, 2, 3, 4, 5);
    // Holds every task back until the main thread has interrupted the parallelizing thread.
    CountDownLatch allowTranslation = new CountDownLatch(1);
    Thread thread = new Thread(() -> {
      try {
        parallelize(numbers.stream(), input -> {
          try {
            allowTranslation.await();
          } catch (InterruptedException e) {
            // Tasks are not expected to be interrupted in this mode; report it as a failure.
            thrown.add(e);
            return;
          }
          translateToString(input);
        });
      } catch (InterruptedException | TimeoutException impossible) {
        thrown.add(impossible);
      }
    });
    thread.start();
    thread.interrupt();
    assertThat(translated).isEmpty();
    allowTranslation.countDown();
    thread.join();
    // Even interrupted, all numbers should be printed.
    assertThat(translated).containsExactlyEntriesIn(mapToString(numbers));
  }

  @Test public void testInterruptible() throws InterruptedException {
    assumeFalse(threading == Threading.DIRECT);
    assumeTrue(mode == Mode.INTERRUPTIBLY);
    maxInFlight = 2;
    List<Integer> numbers = asList(1, 2, 3, 4);
    // Reaches zero once maxInFlight tasks have started running.
    CountDownLatch inflight = new CountDownLatch(maxInFlight);
    CountDownLatch allowTranslation = new CountDownLatch(1);
    AtomicBoolean parallelizationInterrupted = new AtomicBoolean();
    Thread thread = new Thread(() -> {
      try {
        parallelize(numbers.stream(), input -> {
          inflight.countDown();
          try {
            allowTranslation.await();
          } catch (InterruptedException e) {
            // An interrupted task simply gives up without translating its input.
            return;
          }
          translateToString(input);
        });
      } catch (InterruptedException expected) {
        parallelizationInterrupted.set(true);
      } catch (TimeoutException e) {
        thrown.add(e);
      }
    });
    thread.start();
    inflight.await();
    thread.interrupt();
    assertThat(translated).isEmpty();
    allowTranslation.countDown();
    thread.join();
    // Only numbers already inflight are translated.
    assertThat(translated).doesNotContainKey(3);
    assertThat(translated).doesNotContainKey(4);
    assertThat(parallelizationInterrupted.get()).isTrue();
  }

  @Test public void testErrorPropagated() {
    Error error = new Error();
    UncheckedExecutionException exception = assertThrows(
        UncheckedExecutionException.class,
        () -> parallelize(Stream.of(() -> raise(error))));
    assertThat(exception.getCause()).isSameAs(error);
  }

  @Test public void testExceptionPropagated() {
    RuntimeException exception = new RuntimeException();
    UncheckedExecutionException caught = assertThrows(
        UncheckedExecutionException.class,
        () -> parallelize(Stream.of(() -> raise(exception))));
    assertThat(caught.getCause()).isSameAs(exception);
  }

  /** Records {@code i} in {@link #translated}, mapped to its decimal string form. */
  private void translateToString(int i) {
    translated.put(i, Integer.toString(i));
  }

  /** Returns a map from each of {@code keys} to its {@code toString()} value. */
  private static <K> Map<K, String> mapToString(Collection<K> keys) {
    return keys.stream().collect(Collectors.toMap(k -> k, Object::toString));
  }

  /**
   * Keeps track of active threads and makes sure it doesn't exceed {@link #maxInFlight}.
   *
   * <p>If the limit is exceeded, the assertion failure is recorded in {@link #thrown} and
   * {@code task} is not run. The active count is decremented in either case.
   */
  private void runTask(Runnable task) {
    try {
      try {
        assertThat(activeThreads.incrementAndGet()).isAtMost(maxInFlight);
      } catch (Throwable e) {
        thrown.add(e);
        return;
      }
      task.run();
    } finally {
      activeThreads.decrementAndGet();
    }
  }

  /** Runs {@code consumer} on each of {@code inputs} through the entry point under test. */
  private <T> void parallelize(Stream<? extends T> inputs, Consumer<? super T> consumer)
      throws InterruptedException, TimeoutException {
    parallelize(forAll(inputs, consumer));
  }

  /**
   * Runs {@code tasks} through the entry point selected by {@link #mode}, using a new
   * {@code Parallelizer} built from {@link #threadPool} and {@link #maxInFlight}. Each task is
   * wrapped by {@link #runTask}.
   */
  private void parallelize(Stream<? extends Runnable> tasks)
      throws InterruptedException, TimeoutException {
    mode.run(
        new Parallelizer(threadPool, maxInFlight), tasks, this::runTask, timeout);
  }

  /**
   * Blocks the calling thread until it is interrupted (the awaited latch is never counted down),
   * then records {@code key} in {@link #interrupted}.
   */
  private void blockFor(Object key) {
    try {
      new CountDownLatch(1).await();
    } catch (InterruptedException e) {
      interrupted.add(key);
    }
  }

  /**
   * Returns a consumer that delegates to {@code consumer} after {@code delay}. If the sleep is
   * interrupted, the input is dropped without calling {@code consumer}.
   */
  private static <T> Consumer<T> delayed(Duration delay, Consumer<T> consumer) {
    return input -> {
      try {
        Thread.sleep(delay.toMillis());
      } catch (InterruptedException e) {
        return;
      }
      consumer.accept(input);
    };
  }

  /**
   * Creates a task stream such that a task has to be started first before tasks after it can be
   * taken out of the stream. Helps to ensure in-flight status for tasks where we care.
   *
   * <p>The single permit is acquired when an element is pulled from the (lazy) stream and released
   * as the first action of the returned task, so pulling the next element blocks until then.
   */
  private static Stream<Runnable> serialTasks(Runnable... tasks) {
    Semaphore semaphore = new Semaphore(1);
    return asList(tasks).stream().map(task -> {
      semaphore.acquireUninterruptibly();
      return () -> {
        semaphore.release();
        task.run();
      };
    });
  }

  /** Shuts the thread pool down and returns a Truth subject over the {@link #interrupted} keys. */
  private IterableSubject shutdownAndAssertInterruptedKeys() throws InterruptedException {
    shutdownThreadPool();  // Allow left-over threads to respond to interruptions.
    return assertThat(interrupted);
  }

  /**
   * Initiates an orderly shutdown of {@link #threadPool} and waits up to one second for it to
   * terminate. Whether termination actually happened within that time is not checked.
   */
  private void shutdownThreadPool() throws InterruptedException {
    threadPool.shutdown();
    threadPool.awaitTermination(1, TimeUnit.SECONDS);
  }

  /** Throws {@code throwable}; lets a lambda expression raise an arbitrary throwable. */
  private static <E extends Throwable> void raise(E throwable) throws E {
    throw throwable;
  }

  /** Returns one {@code {mode, threading}} pair for every combination of the two enums. */
  @Parameters(name = "{index}: {0}/{1}")
  public static Object[][] data() {
    List<Object[]> groups = new ArrayList<>();
    for (Mode mode : Mode.values()) {
      for (Threading threading : Threading.values()) {
        groups.add(new Object[] {mode, threading});
      }
    }
    return groups.toArray(new Object[0][]);
  }

  /** The {@code Parallelizer} entry points under test. */
  private enum Mode {
    /** Calls the timed {@code parallelize()} overload with the stream itself. */
    INTERRUPTIBLY {
      @Override <T> void run(
          Parallelizer parallelizer,
          Stream<? extends T> inputs,
          Consumer<? super T> consumer,
          Duration timeout)
          throws TimeoutException, InterruptedException {
        parallelizer.parallelize(inputs, consumer, timeout.toMillis(), TimeUnit.MILLISECONDS);
      }
    },
    /** Calls {@code parallelizeUninterruptibly()} with the stream itself; ignores the timeout. */
    UNINTERRUPTIBLY {
      @Override <T> void run(
          Parallelizer parallelizer,
          Stream<? extends T> inputs,
          Consumer<? super T> consumer,
          Duration timeout) {
        parallelizer.parallelizeUninterruptibly(inputs, consumer);
      }
    },
    /** Calls the timed {@code parallelize()} overload with the stream's iterator. */
    INTERRUPTIBLY_FOR_ITERATOR {
      @Override <T> void run(
          Parallelizer parallelizer,
          Stream<? extends T> inputs,
          Consumer<? super T> consumer,
          Duration timeout)
          throws TimeoutException, InterruptedException {
        parallelizer.parallelize(
            inputs.iterator(), consumer, timeout.toMillis(), TimeUnit.MILLISECONDS);
      }
    },
    /** Calls {@code parallelizeUninterruptibly()} with the stream's iterator; ignores timeout. */
    UNINTERRUPTIBLY_FOR_ITERATOR {
      @Override <T> void run(
          Parallelizer parallelizer,
          Stream<? extends T> inputs,
          Consumer<? super T> consumer,
          Duration timeout) {
        parallelizer.parallelizeUninterruptibly(inputs.iterator(), consumer);
      }
    },
    ;

    /** Feeds {@code inputs} to {@code consumer} through this mode's entry point. */
    abstract <T> void run(
        Parallelizer parallelizer,
        Stream<? extends T> inputs,
        Consumer<? super T> consumer,
        Duration timeout)
        throws TimeoutException, InterruptedException;
  }

  /** The kinds of executor service the tests are run against. */
  private enum Threading {
    /** Guava's direct executor service. */
    DIRECT {
      @Override ExecutorService newExecutorService() {
        return MoreExecutors.newDirectExecutorService();
      }
    },
    /** A cached thread pool. */
    CACHED_THREAD_POOL {
      @Override ExecutorService newExecutorService() {
        return Executors.newCachedThreadPool();
      }
    },
    /** A fixed pool of 10 threads. */
    FIXED_THREAD_POOL {
      @Override ExecutorService newExecutorService() {
        return Executors.newFixedThreadPool(10);
      }
    },
    ;

    /** Creates a new executor service of this kind. */
    abstract ExecutorService newExecutorService();
  }
}