/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.direct;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.runners.direct.CommittedResult.OutputType;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
import org.hamcrest.Matchers;
import org.joda.time.Instant;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

/**
 * Tests for {@link DirectTransformExecutor}.
 *
 * <p>Each test registers a hand-written {@link TransformEvaluator} with a mocked {@link
 * TransformEvaluatorRegistry}, runs a {@link DirectTransformExecutor} either inline or on a
 * dedicated single-thread executor, and inspects what was reported to a {@link
 * RegisteringCompletionCallback} (and, where applicable, to test {@link ModelEnforcement}s).
 *
 * <p>Tests that run the executor on another thread bound every wait by {@link
 * #AWAIT_TIMEOUT_SECONDS}, so a missing completion callback fails the test instead of hanging the
 * build. They also always shut their thread pool down so no worker thread outlives the test.
 */
@RunWith(JUnit4.class)
public class DirectTransformExecutorTest {
  /** Upper bound on how long a test waits for work it submitted to another thread. */
  private static final long AWAIT_TIMEOUT_SECONDS = 30;

  @Rule public ExpectedException thrown = ExpectedException.none();
  private PCollection<String> created;

  private AppliedPTransform<?, ?, ?> createdProducer;
  private AppliedPTransform<?, ?, ?> downstreamProducer;

  private CountDownLatch evaluatorCompleted;

  private RegisteringCompletionCallback completionCallback;
  private TransformExecutorService transformEvaluationState;
  private BundleFactory bundleFactory;
  @Mock private DirectMetrics metrics;
  @Mock private EvaluationContext evaluationContext;
  @Mock private TransformEvaluatorRegistry registry;

  @Rule public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);

  /**
   * Sets up the fixture shared by all tests.
   *
   * <p>Builds the pipeline {@code Create.of("foo", "spam", "third") -> WithKeys.of(3)} and applies
   * {@code DirectGraphs.performDirectOverrides}. It then records the producing {@link
   * AppliedPTransform} of the created and the keyed {@link PCollection}.
   *
   * <p>It also creates a fresh single-count completion latch and callback, and stubs {@code
   * evaluationContext.getMetrics()} to return the mocked {@link DirectMetrics}.
   */
  @Before
  public void setup() {
    MockitoAnnotations.initMocks(this);

    bundleFactory = ImmutableListBundleFactory.create();

    transformEvaluationState =
        TransformExecutorServices.parallel(MoreExecutors.newDirectExecutorService());

    evaluatorCompleted = new CountDownLatch(1);
    completionCallback = new RegisteringCompletionCallback(evaluatorCompleted);

    created = p.apply(Create.of("foo", "spam", "third"));
    PCollection<KV<Integer, String>> downstream = created.apply(WithKeys.of(3));

    DirectGraphs.performDirectOverrides(p);
    DirectGraph graph = DirectGraphs.getGraph(p);
    createdProducer = graph.getProducer(created);
    downstreamProducer = graph.getProducer(downstream);

    when(evaluationContext.getMetrics()).thenReturn(metrics);
  }

  /**
   * Checks that a {@code null} input bundle calls {@code finishBundle} but never {@code
   * processElement}.
   *
   * <p>The evaluator's result must be handed to the completion callback, and no exception may be
   * reported.
   */
  @Test
  public void callWithNullInputBundleFinishesBundleAndCompletes() throws Exception {
    final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();
    final AtomicBoolean finishCalled = new AtomicBoolean(false);
    TransformEvaluator<Object> evaluator =
        new TransformEvaluator<Object>() {
          @Override
          public void processElement(WindowedValue<Object> element) throws Exception {
            throw new IllegalArgumentException("Shouldn't be called");
          }

          @Override
          public TransformResult<Object> finishBundle() throws Exception {
            finishCalled.set(true);
            return result;
          }
        };

    when(registry.forApplication(createdProducer, null)).thenReturn(evaluator);

    DirectTransformExecutor<Object> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.emptyList(),
            null,
            createdProducer,
            completionCallback,
            transformEvaluationState);
    executor.run();

    assertThat(finishCalled.get(), is(true));
    assertThat(completionCallback.handledResult, equalTo(result));
    assertThat(completionCallback.handledException, is(nullValue()));
  }

  /**
   * Checks the case where the registry returns no evaluator.
   *
   * <p>The executor must report an empty completion rather than a result or an exception.
   */
  @Test
  public void nullTransformEvaluatorTerminates() throws Exception {
    when(registry.forApplication(createdProducer, null)).thenReturn(null);

    DirectTransformExecutor<Object> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.emptyList(),
            null,
            createdProducer,
            completionCallback,
            transformEvaluationState);
    executor.run();

    assertThat(completionCallback.handledResult, is(nullValue()));
    assertThat(completionCallback.handledEmpty, equalTo(true));
    assertThat(completionCallback.handledException, is(nullValue()));
  }

  /**
   * Checks that every element of the input bundle reaches {@code processElement}.
   *
   * <p>The {@code finishBundle} result must then reach the completion callback. The executor runs
   * on its own thread.
   */
  @Test
  public void inputBundleProcessesEachElementFinishesAndCompletes() throws Exception {
    final TransformResult<String> result =
        StepTransformResult.<String>withoutHold(downstreamProducer).build();
    final Collection<WindowedValue<String>> elementsProcessed = new ArrayList<>();
    TransformEvaluator<String> evaluator =
        new TransformEvaluator<String>() {
          @Override
          public void processElement(WindowedValue<String> element) throws Exception {
            elementsProcessed.add(element);
          }

          @Override
          public TransformResult<String> finishBundle() throws Exception {
            return result;
          }
        };

    WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
    WindowedValue<String> spam = WindowedValue.valueInGlobalWindow("spam");
    WindowedValue<String> third = WindowedValue.valueInGlobalWindow("third");
    CommittedBundle<String> inputBundle =
        bundleFactory.createBundle(created).add(foo).add(spam).add(third).commit(Instant.now());
    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);

    DirectTransformExecutor<String> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.emptyList(),
            inputBundle,
            downstreamProducer,
            completionCallback,
            transformEvaluationState);

    ExecutorService executorService = Executors.newSingleThreadExecutor();
    try {
      Future<?> future = executorService.submit(executor);

      awaitEvaluatorCompleted();
      future.get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    } finally {
      // Let the submitted task finish, then release the worker thread.
      executorService.shutdown();
    }

    assertThat(elementsProcessed, containsInAnyOrder(spam, third, foo));
    assertThat(completionCallback.handledResult, equalTo(result));
    assertThat(completionCallback.handledException, is(nullValue()));
  }

  /**
   * Checks that an exception thrown from {@code processElement} reaches {@code handleException}.
   *
   * <p>No result may be reported.
   */
  @Test
  @SuppressWarnings("FutureReturnValueIgnored") // expected exception checked via completionCallback
  public void processElementThrowsExceptionCallsback() throws Exception {
    final TransformResult<String> result =
        StepTransformResult.<String>withoutHold(downstreamProducer).build();
    final Exception exception = new Exception();
    TransformEvaluator<String> evaluator =
        new TransformEvaluator<String>() {
          @Override
          public void processElement(WindowedValue<String> element) throws Exception {
            throw exception;
          }

          @Override
          public TransformResult<String> finishBundle() throws Exception {
            return result;
          }
        };

    WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
    CommittedBundle<String> inputBundle =
        bundleFactory.createBundle(created).add(foo).commit(Instant.now());
    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);

    DirectTransformExecutor<String> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.emptyList(),
            inputBundle,
            downstreamProducer,
            completionCallback,
            transformEvaluationState);

    ExecutorService executorService = Executors.newSingleThreadExecutor();
    try {
      // The Future is deliberately ignored: the failure is observed through the callback.
      executorService.submit(executor);

      awaitEvaluatorCompleted();
    } finally {
      // shutdown() (not shutdownNow()) so the in-flight task is not interrupted.
      executorService.shutdown();
    }

    assertThat(completionCallback.handledResult, is(nullValue()));
    assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
  }

  /**
   * Checks that an exception thrown from {@code finishBundle} reaches {@code handleException}.
   *
   * <p>No result may be reported.
   */
  @Test
  @SuppressWarnings("FutureReturnValueIgnored") // expected exception checked via completionCallback
  public void finishBundleThrowsExceptionCallsback() throws Exception {
    final Exception exception = new Exception();
    TransformEvaluator<String> evaluator =
        new TransformEvaluator<String>() {
          @Override
          public void processElement(WindowedValue<String> element) throws Exception {}

          @Override
          public TransformResult<String> finishBundle() throws Exception {
            throw exception;
          }
        };

    CommittedBundle<String> inputBundle = bundleFactory.createBundle(created).commit(Instant.now());
    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);

    DirectTransformExecutor<String> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.emptyList(),
            inputBundle,
            downstreamProducer,
            completionCallback,
            transformEvaluationState);

    ExecutorService executorService = Executors.newSingleThreadExecutor();
    try {
      // The Future is deliberately ignored: the failure is observed through the callback.
      executorService.submit(executor);

      awaitEvaluatorCompleted();
    } finally {
      // shutdown() (not shutdownNow()) so the in-flight task is not interrupted.
      executorService.shutdown();
    }

    assertThat(completionCallback.handledResult, is(nullValue()));
    assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
  }

  /**
   * Checks that a model enforcement sees every element both before and after it is processed.
   *
   * <p>The enforcement must also receive the bundle's result when the bundle finishes.
   */
  @Test
  public void callWithEnforcementAppliesEnforcement() throws Exception {
    final TransformResult<Object> result =
        StepTransformResult.withoutHold(downstreamProducer).build();

    TransformEvaluator<Object> evaluator =
        new TransformEvaluator<Object>() {
          @Override
          public void processElement(WindowedValue<Object> element) throws Exception {}

          @Override
          public TransformResult<Object> finishBundle() throws Exception {
            return result;
          }
        };

    WindowedValue<String> fooElem = WindowedValue.valueInGlobalWindow("foo");
    WindowedValue<String> barElem = WindowedValue.valueInGlobalWindow("bar");
    CommittedBundle<String> inputBundle =
        bundleFactory.createBundle(created).add(fooElem).add(barElem).commit(Instant.now());
    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);

    TestEnforcementFactory enforcement = new TestEnforcementFactory();
    DirectTransformExecutor<String> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.<ModelEnforcementFactory>singleton(enforcement),
            inputBundle,
            downstreamProducer,
            completionCallback,
            transformEvaluationState);

    executor.run();
    TestEnforcement<?> testEnforcement = enforcement.instance;
    assertThat(testEnforcement.beforeElements, containsInAnyOrder(barElem, fooElem));
    assertThat(testEnforcement.afterElements, containsInAnyOrder(barElem, fooElem));
    assertThat(testEnforcement.finishedBundles, Matchers.contains(result));
  }

  /**
   * Checks that a {@link RuntimeException} thrown from an enforcement's {@code afterFinish} hook
   * propagates.
   *
   * <p>It must surface as the cause of the submitted task's failure, with the message {@code
   * "afterFinish"}.
   */
  @Test
  public void callWithEnforcementThrowsOnFinishPropagates() throws Exception {
    final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();

    TransformEvaluator<Object> evaluator =
        new TransformEvaluator<Object>() {
          @Override
          public void processElement(WindowedValue<Object> element) throws Exception {}

          @Override
          public TransformResult<Object> finishBundle() throws Exception {
            return result;
          }
        };

    WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
    CommittedBundle<String> inputBundle =
        bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);

    DirectTransformExecutor<String> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.<ModelEnforcementFactory>singleton(
                new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_BUNDLE)),
            inputBundle,
            downstreamProducer,
            completionCallback,
            transformEvaluationState);

    ExecutorService executorService = Executors.newSingleThreadExecutor();
    try {
      Future<?> task = executorService.submit(executor);

      thrown.expectCause(isA(RuntimeException.class));
      thrown.expectMessage("afterFinish");
      task.get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    } finally {
      executorService.shutdown();
    }
  }

  /**
   * Checks that a {@link RuntimeException} thrown from an enforcement's {@code afterElement} hook
   * propagates.
   *
   * <p>It must surface as the cause of the submitted task's failure, with the message {@code
   * "afterElement"}.
   */
  @Test
  public void callWithEnforcementThrowsOnElementPropagates() throws Exception {
    final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();

    TransformEvaluator<Object> evaluator =
        new TransformEvaluator<Object>() {
          @Override
          public void processElement(WindowedValue<Object> element) throws Exception {}

          @Override
          public TransformResult<Object> finishBundle() throws Exception {
            return result;
          }
        };

    WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
    CommittedBundle<String> inputBundle =
        bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);

    DirectTransformExecutor<String> executor =
        new DirectTransformExecutor<>(
            evaluationContext,
            registry,
            Collections.<ModelEnforcementFactory>singleton(
                new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_ELEMENT)),
            inputBundle,
            downstreamProducer,
            completionCallback,
            transformEvaluationState);

    ExecutorService executorService = Executors.newSingleThreadExecutor();
    try {
      Future<?> task = executorService.submit(executor);

      thrown.expectCause(isA(RuntimeException.class));
      thrown.expectMessage("afterElement");
      task.get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    } finally {
      executorService.shutdown();
    }
  }

  /**
   * Blocks until the completion callback has counted down {@link #evaluatorCompleted}.
   *
   * <p>Fails the test, rather than hanging, if that does not happen within {@link
   * #AWAIT_TIMEOUT_SECONDS}.
   */
  private void awaitEvaluatorCompleted() throws InterruptedException {
    assertThat(
        "completion callback was not invoked within " + AWAIT_TIMEOUT_SECONDS + " seconds",
        evaluatorCompleted.await(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS),
        is(true));
  }

  /**
   * {@link CompletionCallback} that records what it is told and counts down a latch.
   *
   * <p>It stores the last result, empty-notification or exception it receives, and counts down
   * {@code onMethod} for each of those three kinds of call. {@link Error}s are rethrown.
   */
  private static class RegisteringCompletionCallback implements CompletionCallback {
    private TransformResult<?> handledResult = null;
    private boolean handledEmpty = false;
    private Exception handledException = null;
    private final CountDownLatch onMethod;

    private RegisteringCompletionCallback(CountDownLatch onMethod) {
      this.onMethod = onMethod;
    }

    /**
     * Records {@code result}, counts down the latch, and returns a minimal {@link CommittedResult}.
     *
     * <p>The returned result has no outputs and no output types. Its unprocessed bundle holds the
     * result's unprocessed elements, or is empty when there is no input bundle or no such elements.
     */
    @Override
    public CommittedResult handleResult(CommittedBundle<?> inputBundle, TransformResult<?> result) {
      handledResult = result;
      onMethod.countDown();
      // Raw type: the element type of the unprocessed elements is not statically known here.
      @SuppressWarnings("rawtypes")
      Iterable unprocessedElements =
          result.getUnprocessedElements() == null
              ? Collections.emptyList()
              : result.getUnprocessedElements();

      Optional<? extends CommittedBundle<?>> unprocessedBundle;
      if (inputBundle == null || Iterables.isEmpty(unprocessedElements)) {
        unprocessedBundle = Optional.empty();
      } else {
        unprocessedBundle =
            Optional.<CommittedBundle<?>>of(inputBundle.withElements(unprocessedElements));
      }
      return CommittedResult.create(
          result, unprocessedBundle, Collections.emptyList(), EnumSet.noneOf(OutputType.class));
    }

    /** Records that the transform completed with no evaluator, and counts down the latch. */
    @Override
    public void handleEmpty(AppliedPTransform<?, ?, ?> transform) {
      handledEmpty = true;
      onMethod.countDown();
    }

    /** Records the reported exception and counts down the latch. */
    @Override
    public void handleException(CommittedBundle<?> inputBundle, Exception e) {
      handledException = e;
      onMethod.countDown();
    }

    /** Rethrows {@code err} without counting down the latch. */
    @Override
    public void handleError(Error err) {
      throw err;
    }
  }

  /**
   * {@link ModelEnforcementFactory} that creates a new {@link TestEnforcement} for each bundle.
   *
   * <p>The most recently created enforcement is kept in {@link #instance}.
   */
  private static class TestEnforcementFactory implements ModelEnforcementFactory {
    private TestEnforcement<?> instance;

    @Override
    public <T> TestEnforcement<T> forBundle(
        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
      TestEnforcement<T> newEnforcement = new TestEnforcement<>();
      instance = newEnforcement;
      return newEnforcement;
    }
  }

  /**
   * {@link ModelEnforcement} that records what it observes.
   *
   * <p>It keeps every element passed to {@code beforeElement} and {@code afterElement}, and every
   * result passed to {@code afterFinish}.
   */
  private static class TestEnforcement<T> implements ModelEnforcement<T> {
    private final List<WindowedValue<T>> beforeElements = new ArrayList<>();
    private final List<WindowedValue<T>> afterElements = new ArrayList<>();
    private final List<TransformResult<?>> finishedBundles = new ArrayList<>();

    @Override
    public void beforeElement(WindowedValue<T> element) {
      beforeElements.add(element);
    }

    @Override
    public void afterElement(WindowedValue<T> element) {
      afterElements.add(element);
    }

    @Override
    public void afterFinish(
        CommittedBundle<T> input,
        TransformResult<T> result,
        Iterable<? extends CommittedBundle<?>> outputs) {
      finishedBundles.add(result);
    }
  }

  /**
   * {@link ModelEnforcementFactory} whose enforcements throw a {@link RuntimeException} at the stage
   * selected by {@link When}.
   *
   * <p>The exception message is the name of the hook that threw: {@code "forBundle"}, {@code
   * "beforeElement"}, {@code "afterElement"} or {@code "afterFinish"}.
   */
  private static class ThrowingEnforcementFactory implements ModelEnforcementFactory {
    private final When when;

    private ThrowingEnforcementFactory(When when) {
      this.when = when;
    }

    /** The enforcement stage at which to throw. */
    enum When {
      BEFORE_BUNDLE,
      BEFORE_ELEMENT,
      AFTER_ELEMENT,
      AFTER_BUNDLE
    }

    @Override
    public <T> ModelEnforcement<T> forBundle(
        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
      if (when == When.BEFORE_BUNDLE) {
        throw new RuntimeException("forBundle");
      }
      return new ThrowingEnforcement<>();
    }

    /** Inner (non-static) so it can read the enclosing factory's {@code when}. */
    private class ThrowingEnforcement<T> implements ModelEnforcement<T> {
      @Override
      public void beforeElement(WindowedValue<T> element) {
        if (when == When.BEFORE_ELEMENT) {
          throw new RuntimeException("beforeElement");
        }
      }

      @Override
      public void afterElement(WindowedValue<T> element) {
        if (when == When.AFTER_ELEMENT) {
          throw new RuntimeException("afterElement");
        }
      }

      @Override
      public void afterFinish(
          CommittedBundle<T> input,
          TransformResult<T> result,
          Iterable<? extends CommittedBundle<?>> outputs) {
        if (when == When.AFTER_BUNDLE) {
          throw new RuntimeException("afterFinish");
        }
      }
    }
  }
}