/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.nemo.compiler.frontend.beam.transform;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.vertex.transform.Transform;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.compiler.frontend.beam.NemoPipelineOptions;
import org.apache.nemo.compiler.frontend.beam.SideInputElement;
import org.apache.reef.io.Tuple;
import org.junit.Before;
import org.junit.Test;

import java.util.*;

import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;

/**
 * Tests for {@link DoFnTransform} and {@link PushBackDoFnTransform}: single output, count- and
 * time-based bundling, tagged (multi) outputs, and side-input handling.
 */
public final class DoFnTransformTest {

  // Views for testing side inputs; rebuilt before every test in setUp().
  private PCollectionView<Iterable<String>> view1;
  private PCollectionView<Iterable<String>> view2;

  // Coders are passed as null to every transform below; presumably no coder-dependent path is
  // exercised by these tests — revisit if a test starts relying on encoding.
  private static final Coder NULL_INPUT_CODER = null;
  private static final Map<TupleTag<?>, Coder<?>> NULL_OUTPUT_CODERS = null;

  /**
   * Builds two single-element iterable side-input views ("1" and "2"), each from its own pipeline.
   */
  @Before
  public void setUp() {
    view1 = Pipeline.create().apply(Create.of("1")).apply(View.asIterable());
    view2 = Pipeline.create().apply(Create.of("2")).apply(View.asIterable());
  }

  /**
   * An identity DoFn should forward an element unchanged to the main output.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void testSingleOutput() {

    final TupleTag<String> outputTag = new TupleTag<>("main-output");

    final DoFnTransform<String, String> doFnTransform =
      new DoFnTransform<>(
        new IdentityDoFn<>(),
        NULL_INPUT_CODER,
        NULL_OUTPUT_CODERS,
        outputTag,
        Collections.emptyList(),
        WindowingStrategy.globalDefault(),
        PipelineOptionsFactory.as(NemoPipelineOptions.class),
        DisplayData.none(),
        DoFnSchemaInformation.create(),
        Collections.emptyMap());

    final Transform.Context context = mock(Transform.Context.class);
    final TestOutputCollector<String> oc = new TestOutputCollector<>();
    doFnTransform.prepare(context, oc);

    doFnTransform.onData(WindowedValue.valueInGlobalWindow("Hello"));

    // JUnit convention: expected value first, actual value second.
    assertEquals(WindowedValue.valueInGlobalWindow("Hello"), oc.outputs.get(0));

    doFnTransform.close();
  }


  /**
   * With a max bundle size of 3 and a very long max bundle time, each run of three elements
   * should finish a bundle that has processed exactly three elements.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void testCountBundle() {

    final TupleTag<String> outputTag = new TupleTag<>("main-output");
    final NemoPipelineOptions pipelineOptions = PipelineOptionsFactory.as(NemoPipelineOptions.class);
    pipelineOptions.setMaxBundleSize(3L);
    pipelineOptions.setMaxBundleTimeMills(10000000L);

    // Filled by BundleTestDoFn#finishBundle with the element count of each finished bundle.
    final List<Integer> bundleOutput = new ArrayList<>();

    final DoFnTransform<String, String> doFnTransform =
      new DoFnTransform<>(
        new BundleTestDoFn(bundleOutput),
        NULL_INPUT_CODER,
        NULL_OUTPUT_CODERS,
        outputTag,
        Collections.emptyList(),
        WindowingStrategy.globalDefault(),
        pipelineOptions,
        DisplayData.none(),
        DoFnSchemaInformation.create(),
        Collections.emptyMap());

    final Transform.Context context = mock(Transform.Context.class);
    final OutputCollector<WindowedValue<String>> oc = new TestOutputCollector<>();
    doFnTransform.prepare(context, oc);

    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));

    assertEquals(3, (int) bundleOutput.get(0));

    bundleOutput.clear();

    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));

    assertEquals(3, (int) bundleOutput.get(0));

    doFnTransform.close();
  }

  /**
   * With a huge max bundle size and a 1-second max bundle time, the first bundle should finish
   * only after at least {@code maxBundleTimeMills} has elapsed, and should contain every element
   * fed up to that point.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void testTimeBundle() {

    final long maxBundleTimeMills = 1000L;
    // Upper bound on how long elements are fed, so a regression fails instead of hanging the build.
    final long giveUpAfterMills = 60 * maxBundleTimeMills;
    final TupleTag<String> outputTag = new TupleTag<>("main-output");
    final NemoPipelineOptions pipelineOptions = PipelineOptionsFactory.as(NemoPipelineOptions.class);
    pipelineOptions.setMaxBundleSize(10000000L);
    pipelineOptions.setMaxBundleTimeMills(maxBundleTimeMills);

    // Filled by BundleTestDoFn#finishBundle with the element count of each finished bundle.
    final List<Integer> bundleOutput = new ArrayList<>();

    final DoFnTransform<String, String> doFnTransform =
      new DoFnTransform<>(
        new BundleTestDoFn(bundleOutput),
        NULL_INPUT_CODER,
        NULL_OUTPUT_CODERS,
        outputTag,
        Collections.emptyList(),
        WindowingStrategy.globalDefault(),
        pipelineOptions,
        DisplayData.none(),
        DoFnSchemaInformation.create(),
        Collections.emptyMap());

    final Transform.Context context = mock(Transform.Context.class);
    final OutputCollector<WindowedValue<String>> oc = new TestOutputCollector<>();

    final long startTime = System.currentTimeMillis();
    doFnTransform.prepare(context, oc);

    int count = 0;
    while (bundleOutput.isEmpty()) {
      assertTrue("No bundle finished within " + giveUpAfterMills + " ms",
        System.currentTimeMillis() - startTime < giveUpAfterMills);
      doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));
      count += 1;
      try {
        Thread.sleep(10);
      } catch (final InterruptedException e) {
        // Restore the interrupt status so the test runner can observe it, then abort the test.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }

    final long endTime = System.currentTimeMillis();
    assertEquals(count, (int) bundleOutput.get(0));
    assertTrue(endTime - startTime >= maxBundleTimeMills);

    doFnTransform.close();
  }

  /**
   * A DoFn writing to the main output and two additional tagged outputs should have each element
   * routed to the expected tags.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void testMultiOutputOutput() {

    final TupleTag<String> mainOutput = new TupleTag<>("main-output");
    final TupleTag<String> additionalOutput1 = new TupleTag<>("output-1");
    final TupleTag<String> additionalOutput2 = new TupleTag<>("output-2");

    final ImmutableList<TupleTag<?>> tags = ImmutableList.of(additionalOutput1, additionalOutput2);

    final DoFnTransform<String, String> doFnTransform =
      new DoFnTransform<>(
        new MultiOutputDoFn(additionalOutput1, additionalOutput2),
        NULL_INPUT_CODER,
        NULL_OUTPUT_CODERS,
        mainOutput,
        tags,
        WindowingStrategy.globalDefault(),
        PipelineOptionsFactory.as(NemoPipelineOptions.class),
        DisplayData.none(),
        DoFnSchemaInformation.create(),
        Collections.emptyMap());

    // mock context
    final Transform.Context context = mock(Transform.Context.class);

    final TestOutputCollector<String> oc = new TestOutputCollector<>();
    doFnTransform.prepare(context, oc);

    doFnTransform.onData(WindowedValue.valueInGlobalWindow("one"));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow("two"));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow("hello"));

    // main output
    assertEquals(WindowedValue.valueInGlobalWindow("got: hello"), oc.outputs.get(0));

    // additional output 1
    assertTrue(oc.getTaggedOutputs().contains(
      new Tuple<>(additionalOutput1.getId(), WindowedValue.valueInGlobalWindow("extra: one"))));
    assertTrue(oc.getTaggedOutputs().contains(
      new Tuple<>(additionalOutput1.getId(), WindowedValue.valueInGlobalWindow("got: hello"))));

    // additional output 2
    assertTrue(oc.getTaggedOutputs().contains(
      new Tuple<>(additionalOutput2.getId(), WindowedValue.valueInGlobalWindow("extra: two"))));
    assertTrue(oc.getTaggedOutputs().contains(
      new Tuple<>(additionalOutput2.getId(), WindowedValue.valueInGlobalWindow("got: hello"))));

    doFnTransform.close();
  }

  /**
   * Exercises {@link PushBackDoFnTransform} with main input arriving both before and after its
   * side inputs, then checks that a max-timestamp watermark clears the materialized side inputs.
   */
  @Test
  @SuppressWarnings({"unchecked", "rawtypes"})
  public void testSideInputs() {
    // mock context
    final Transform.Context context = mock(Transform.Context.class);
    // SimpleSideInputDoFn emits Strings, so the main output tag carries Strings.
    final TupleTag<String> outputTag = new TupleTag<>("main-output");

    final WindowedValue<String> firstElement = WindowedValue.valueInGlobalWindow("first");
    final WindowedValue<String> secondElement = WindowedValue.valueInGlobalWindow("second");

    final SideInputElement firstSideinput = new SideInputElement<>(0, ImmutableList.of("1"));
    final SideInputElement secondSideinput = new SideInputElement<>(1, ImmutableList.of("2"));

    final Map<Integer, PCollectionView<?>> sideInputMap = new HashMap<>();
    sideInputMap.put(firstSideinput.getSideInputIndex(), view1);
    sideInputMap.put(secondSideinput.getSideInputIndex(), view2);
    final PushBackDoFnTransform<String, String> doFnTransform =
      new PushBackDoFnTransform(
        new SimpleSideInputDoFn<String>(view1, view2),
        NULL_INPUT_CODER,
        NULL_OUTPUT_CODERS,
        outputTag,
        Collections.emptyList(),
        WindowingStrategy.globalDefault(),
        sideInputMap, /* side inputs */
        PipelineOptionsFactory.as(NemoPipelineOptions.class),
        DisplayData.none(),
        DoFnSchemaInformation.create(),
        Collections.emptyMap());

    final TestOutputCollector<String> oc = new TestOutputCollector<>();
    doFnTransform.prepare(context, oc);

    // Main input first, Side inputs later
    doFnTransform.onData(firstElement);

    doFnTransform.onData(WindowedValue.valueInGlobalWindow(firstSideinput));
    doFnTransform.onData(WindowedValue.valueInGlobalWindow(secondSideinput));
    assertEquals(
      WindowedValue.valueInGlobalWindow(
        concat(firstElement.getValue(), firstSideinput.getSideInputValue(), secondSideinput.getSideInputValue())),
      oc.getOutput().get(0));

    // Side inputs first, Main input later
    doFnTransform.onData(secondElement);
    assertEquals(
      WindowedValue.valueInGlobalWindow(
        concat(secondElement.getValue(), firstSideinput.getSideInputValue(), secondSideinput.getSideInputValue())),
      oc.getOutput().get(1));

    // There should be only 2 final outputs
    assertEquals(2, oc.getOutput().size());

    // The side inputs should be "READY"
    assertTrue(doFnTransform.getSideInputReader().isReady(view1, GlobalWindow.INSTANCE));
    assertTrue(doFnTransform.getSideInputReader().isReady(view2, GlobalWindow.INSTANCE));

    // This watermark should remove the side inputs. (Now should be "NOT READY")
    doFnTransform.onWatermark(new Watermark(GlobalWindow.TIMESTAMP_MAX_VALUE.getMillis()));
    final Iterable<?> materializedSideInput1 =
      doFnTransform.getSideInputReader().get(view1, GlobalWindow.INSTANCE);
    final Iterable<?> materializedSideInput2 =
      doFnTransform.getSideInputReader().get(view2, GlobalWindow.INSTANCE);
    assertFalse(materializedSideInput1.iterator().hasNext());
    assertFalse(materializedSideInput2.iterator().hasNext());

    // Closing must not emit any extra output: still only 2 final outputs.
    doFnTransform.close();
    assertEquals(2, oc.getOutput().size());
  }


  /**
   * Bundle test DoFn: counts the elements processed in each bundle (reset in startBundle) and
   * appends that count to {@code bundleOutput} when the bundle finishes. Elements are forwarded
   * unchanged.
   */
  private static class BundleTestDoFn extends DoFn<String, String> {
    // Elements processed in the current bundle.
    private int count;

    private final List<Integer> bundleOutput;

    BundleTestDoFn(final List<Integer> bundleOutput) {
      this.bundleOutput = bundleOutput;
    }

    @ProcessElement
    public void processElement(final ProcessContext c) throws Exception {
      count += 1;
      c.output(c.element());
    }

    @StartBundle
    public void startBundle(final StartBundleContext c) {
      count = 0;
    }

    @FinishBundle
    public void finishBundle(final FinishBundleContext c) {
      bundleOutput.add(count);
    }
  }

  /**
   * Identity DoFn: outputs every input element unchanged.
   *
   * @param <T> element type
   */
  private static class IdentityDoFn<T> extends DoFn<T, T> {

    @ProcessElement
    public void processElement(final ProcessContext c) throws Exception {
      c.output(c.element());
    }
  }

  /**
   * Side input DoFn: for each element, reads both side-input views and outputs
   * {@code "element / view1Value / view2Value"}.
   *
   * @param <T> main input element type
   */
  private static class SimpleSideInputDoFn<T> extends DoFn<T, String> {
    private final PCollectionView<?> view1;
    private final PCollectionView<?> view2;

    public SimpleSideInputDoFn(final PCollectionView<?> view1,
                               final PCollectionView<?> view2) {
      this.view1 = view1;
      this.view2 = view2;
    }

    @ProcessElement
    public void processElement(final ProcessContext c) throws Exception {
      final T element = c.element();
      final Object view1Value = c.sideInput(view1);
      final Object view2Value = c.sideInput(view2);

      c.output(concat(element, view1Value, view2Value));
    }
  }

  /**
   * Joins the string forms of three objects with {@code " / "}.
   *
   * @param obj1 first object; must be non-null ({@code toString()} is called on it)
   * @param obj2 second object
   * @param obj3 third object
   * @return {@code obj1 + " / " + obj2 + " / " + obj3}
   */
  private static String concat(final Object obj1, final Object obj2, final Object obj3) {
    return obj1.toString() + " / " + obj2 + " / " + obj3;
  }


  /**
   * Multi output DoFn: routes {@code "one"} to the first additional output and {@code "two"} to
   * the second; any other element is prefixed with {@code "got: "} and emitted to the main output
   * and both additional outputs.
   */
  private static class MultiOutputDoFn extends DoFn<String, String> {
    private final TupleTag<String> additionalOutput1;
    private final TupleTag<String> additionalOutput2;

    public MultiOutputDoFn(final TupleTag<String> additionalOutput1,
                           final TupleTag<String> additionalOutput2) {
      this.additionalOutput1 = additionalOutput1;
      this.additionalOutput2 = additionalOutput2;
    }

    @ProcessElement
    public void processElement(final ProcessContext c) throws Exception {
      if ("one".equals(c.element())) {
        c.output(additionalOutput1, "extra: one");
      } else if ("two".equals(c.element())) {
        c.output(additionalOutput2, "extra: two");
      } else {
        c.output("got: " + c.element());
        c.output(additionalOutput1, "got: " + c.element());
        c.output(additionalOutput2, "got: " + c.element());
      }
    }
  }
}