/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.app.dag.impl;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.hadoop.io.DataInputByteBuffer;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.api.EdgeManagerPlugin;
import org.apache.tez.dag.api.EdgeManagerPluginContext;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
import org.apache.tez.dag.api.EdgeProperty.DataSourceType;
import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.app.dag.Task;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.api.impl.EventMetaData;
import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType;
import org.apache.tez.runtime.api.impl.GroupInputSpec;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.apache.tez.test.EdgeManagerForTest;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentCaptor;

import com.google.common.collect.Maps;

public class TestEdge {

  @Test (timeout = 5000)
  public void testOneToOneEdgeManager() {
    EdgeManagerPluginContext mockContext = mock(EdgeManagerPluginContext.class);
    when(mockContext.getSourceVertexName()).thenReturn("Source");
    when(mockContext.getDestinationVertexName()).thenReturn("Destination");
    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
    OneToOneEdgeManager manager = new OneToOneEdgeManager(mockContext);
    manager.initialize();
    Map<Integer, List<Integer>> destinationTaskAndInputIndices = Maps.newHashMap();
    DataMovementEvent event = DataMovementEvent.create(1, null);

    // fail when source and destination are inconsistent
    when(mockContext.getDestinationVertexNumTasks()).thenReturn(4);
    try {
      manager.routeDataMovementEventToDestination(event, 1, 1, destinationTaskAndInputIndices);
      Assert.fail();
    } catch (IllegalStateException e) {
      Assert.assertTrue(e.getMessage().contains("1-1 source and destination task counts must match"));
    }
    
    // now make it consistent
    when(mockContext.getDestinationVertexNumTasks()).thenReturn(3);
    manager.routeDataMovementEventToDestination(event, 1, 1, destinationTaskAndInputIndices);
    Assert.assertEquals(1, destinationTaskAndInputIndices.size());
    Assert.assertEquals(1, destinationTaskAndInputIndices.entrySet().iterator().next().getKey()
        .intValue());
    Assert.assertEquals(0, destinationTaskAndInputIndices.entrySet().iterator().next().getValue()
        .get(0).intValue());
  }

  @Test (timeout = 5000)
  public void testOneToOneEdgeManagerODR() {
    EdgeManagerPluginContext mockContext = mock(EdgeManagerPluginContext.class);
    when(mockContext.getSourceVertexName()).thenReturn("Source");
    when(mockContext.getDestinationVertexName()).thenReturn("Destination");
    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
    OneToOneEdgeManagerOnDemand manager = new OneToOneEdgeManagerOnDemand(mockContext);
    manager.initialize();
    Map<Integer, List<Integer>> destinationTaskAndInputIndices = Maps.newHashMap();
    DataMovementEvent event = DataMovementEvent.create(1, null);

    // fail when source and destination are inconsistent
    when(mockContext.getDestinationVertexNumTasks()).thenReturn(4);
    try {
      manager.routeDataMovementEventToDestination(event, 1, 1, destinationTaskAndInputIndices);
      Assert.fail();
    } catch (IllegalStateException e) {
      Assert.assertTrue(e.getMessage().contains("1-1 source and destination task counts must match"));
    }

    // now make it consistent
    when(mockContext.getDestinationVertexNumTasks()).thenReturn(3);
    manager.routeDataMovementEventToDestination(event, 1, 1, destinationTaskAndInputIndices);
    Assert.assertEquals(1, destinationTaskAndInputIndices.size());
    Assert.assertEquals(1, destinationTaskAndInputIndices.entrySet().iterator().next().getKey()
        .intValue());
    Assert.assertEquals(0, destinationTaskAndInputIndices.entrySet().iterator().next().getValue()
        .get(0).intValue());
  }

  @Test(timeout = 5000)
  public void testScatterGatherManager() {
    EdgeManagerPluginContext mockContext = mock(EdgeManagerPluginContext.class);
    when(mockContext.getSourceVertexName()).thenReturn("Source");
    when(mockContext.getDestinationVertexName()).thenReturn("Destination");
    ScatterGatherEdgeManager manager = new ScatterGatherEdgeManager(mockContext);
    manager.initialize();

    when(mockContext.getDestinationVertexNumTasks()).thenReturn(-1);
    try {
      manager.getNumSourceTaskPhysicalOutputs(0);
      Assert.fail();
    } catch (IllegalArgumentException e) {
      e.printStackTrace();
      Assert.assertTrue(e.getMessage()
          .contains("ScatteGather edge manager must have destination vertex task parallelism specified"));
    }
    when(mockContext.getDestinationVertexNumTasks()).thenReturn(0);
    manager.getNumSourceTaskPhysicalOutputs(0);
  }

  @SuppressWarnings({ "rawtypes" })
  @Test (timeout = 5000)
  public void testCompositeEventHandling() throws TezException {
    EventHandler eventHandler = mock(EventHandler.class);
    EdgeProperty edgeProp = EdgeProperty.create(DataMovementType.SCATTER_GATHER,
        DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, mock(OutputDescriptor.class),
        mock(InputDescriptor.class));
    Edge edge = new Edge(edgeProp, eventHandler, new TezConfiguration());
    
    TezVertexID srcVertexID = createVertexID(1);
    TezVertexID destVertexID = createVertexID(2);
    LinkedHashMap<TezTaskID, Task> srcTasks = mockTasks(srcVertexID, 1);
    LinkedHashMap<TezTaskID, Task> destTasks = mockTasks(destVertexID, 5);
    
    TezTaskID srcTaskID = srcTasks.keySet().iterator().next();
    
    Vertex srcVertex = mockVertex("src", srcVertexID, srcTasks);
    Vertex destVertex = mockVertex("dest", destVertexID, destTasks);
    
    edge.setSourceVertex(srcVertex);
    edge.setDestinationVertex(destVertex);
    edge.initialize();
    
    TezTaskAttemptID srcTAID = createTAIDForTest(srcTaskID, 2); // Task0, Attempt 0
    
    EventMetaData srcMeta = new EventMetaData(EventProducerConsumerType.OUTPUT, "consumerVertex", "producerVertex", srcTAID);
    
    // Verification via a CompositeEvent
    CompositeDataMovementEvent cdmEvent = CompositeDataMovementEvent.create(0, destTasks.size(),
        ByteBuffer.wrap("bytes".getBytes()));
    cdmEvent.setVersion(2); // AttemptNum
    TezEvent tezEvent = new TezEvent(cdmEvent, srcMeta);
    // Event setup to look like it would after the Vertex is done with it.

    edge.sendTezEventToDestinationTasks(tezEvent);
    verifyEvents(srcTAID, destTasks);

    // Same Verification via regular DataMovementEvents
    // Reset the mock
    resetTaskMocks(destTasks.values());

    for (int i = 0 ; i < destTasks.size() ; i++) {
      DataMovementEvent dmEvent = DataMovementEvent.create(i, ByteBuffer.wrap("bytes".getBytes()));
      dmEvent.setVersion(2);
      tezEvent = new TezEvent(dmEvent, srcMeta);
      edge.sendTezEventToDestinationTasks(tezEvent);
    }
    verifyEvents(srcTAID, destTasks);
  }

  private void verifyEvents(TezTaskAttemptID srcTAID, LinkedHashMap<TezTaskID, Task> destTasks) {
    int count = 0;

    for (Entry<TezTaskID, Task> taskEntry : destTasks.entrySet()) {
      Task mockTask = taskEntry.getValue();
      ArgumentCaptor<TezEvent> args = ArgumentCaptor.forClass(TezEvent.class);
      verify(mockTask, times(1)).registerTezEvent(args.capture());
      TezEvent capturedEvent = args.getValue();

      DataMovementEvent dmEvent = (DataMovementEvent) capturedEvent.getEvent();
      assertEquals(srcTAID.getId(), dmEvent.getVersion());
      assertEquals(count++, dmEvent.getSourceIndex());
      assertEquals(srcTAID.getTaskID().getId(), dmEvent.getTargetIndex());
      byte[] res = new byte[dmEvent.getUserPayload().limit() - dmEvent.getUserPayload().position()];
      dmEvent.getUserPayload().slice().get(res);
      assertTrue(Arrays.equals("bytes".getBytes(), res));
    }
  }

  private void resetTaskMocks(Collection<Task> tasks) {
    for (Task task : tasks) {
      TezTaskID taskID = task.getTaskId();
      reset(task);
      doReturn(taskID).when(task).getTaskId();
    }
  }

  private LinkedHashMap<TezTaskID, Task> mockTasks(TezVertexID vertexID, int numTasks) {
    LinkedHashMap<TezTaskID, Task> tasks = new LinkedHashMap<TezTaskID, Task>();
    for (int i = 0 ; i < numTasks ; i++) {
      Task task = mock(Task.class);
      TezTaskID taskID = TezTaskID.getInstance(vertexID, i);
      doReturn(taskID).when(task).getTaskId();
      tasks.put(taskID, task);
    }
    return tasks;
  }
  
  private Vertex mockVertex(String name, TezVertexID vertexID, LinkedHashMap<TezTaskID, Task> tasks) {
    Vertex vertex = mock(Vertex.class);
    doReturn(vertexID).when(vertex).getVertexId();
    doReturn(name).when(vertex).getName();
    doReturn(tasks).when(vertex).getTasks();
    doReturn(tasks.size()).when(vertex).getTotalTasks();
    for (Entry<TezTaskID, Task> entry : tasks.entrySet()) {
      doReturn(entry.getValue()).when(vertex).getTask(eq(entry.getKey()));
      doReturn(entry.getValue()).when(vertex).getTask(eq(entry.getKey().getId()));
    }
    return vertex;
  }
  
  private TezVertexID createVertexID(int id) {
    TezDAGID dagID = TezDAGID.getInstance("1000", 1, 1);
    TezVertexID vertexID = TezVertexID.getInstance(dagID, id);
    return vertexID;
  }
  
  private TezTaskAttemptID createTAIDForTest(TezTaskID taskID, int taId) {
    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, taId);
    return taskAttemptID;
  }

  @Test(timeout = 5000)
  public void testInvalidPhysicalInputCount() throws Exception {
    EventHandler mockEventHandler = mock(EventHandler.class);
    Edge edge = new Edge(EdgeProperty.create(
        EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName())
          .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(-1,1,1,1).toUserPayload()),
        DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create(""),
        InputDescriptor.create("")), mockEventHandler, new TezConfiguration());
    TezVertexID v1Id = createVertexID(1);
    TezVertexID v2Id = createVertexID(2);
    edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.initialize();
    try {
      edge.getDestinationSpec(0);
      Assert.fail();
    } catch (AMUserCodeException e) {
      e.printStackTrace();
      assertTrue(e.getCause().getMessage().contains("PhysicalInputCount should not be negative"));
    }
  }

  @Test(timeout = 5000)
  public void testInvalidPhysicalOutputCount() throws Exception {
    EventHandler mockEventHandler = mock(EventHandler.class);
    Edge edge = new Edge(EdgeProperty.create(
        EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName())
          .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(1,-1,1,1).toUserPayload()),
        DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create(""),
        InputDescriptor.create("")), mockEventHandler, new TezConfiguration());
    TezVertexID v1Id = createVertexID(1);
    TezVertexID v2Id = createVertexID(2);
    edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.initialize();
    try {
      edge.getSourceSpec(0);
      Assert.fail();
    } catch (AMUserCodeException e) {
      e.printStackTrace();
      assertTrue(e.getCause().getMessage().contains("PhysicalOutputCount should not be negative"));
    }
  }

  @Test(timeout = 5000)
  public void testInvalidConsumerNumber() throws Exception {
    EventHandler mockEventHandler = mock(EventHandler.class);
    Edge edge = new Edge(EdgeProperty.create(
        EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName())
          .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(1,1,0,1).toUserPayload()),
        DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create(""),
        InputDescriptor.create("")), mockEventHandler, new TezConfiguration());
    TezVertexID v1Id = createVertexID(1);
    TezVertexID v2Id = createVertexID(2);
    edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.initialize();
    try {
      TezEvent ireEvent = new TezEvent(InputReadErrorEvent.create("diag", 0, 1),
          new EventMetaData(EventProducerConsumerType.INPUT, "v2", "v1",
              TezTaskAttemptID.getInstance(TezTaskID.getInstance(v2Id, 1), 1)));
      edge.sendTezEventToSourceTasks(ireEvent);
      Assert.fail();
    } catch (AMUserCodeException e) {
      e.printStackTrace();
      assertTrue(e.getCause().getMessage().contains("ConsumerTaskNum must be positive"));
    }
  }

  @Test(timeout = 5000)
  public void testInvalidSourceTaskIndex() throws Exception {
    EventHandler mockEventHandler = mock(EventHandler.class);
    Edge edge = new Edge(EdgeProperty.create(
        EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName())
          .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(1,1,1,-1).toUserPayload()),
        DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create(""),
        InputDescriptor.create("")), mockEventHandler, new TezConfiguration());
    TezVertexID v1Id = createVertexID(1);
    TezVertexID v2Id = createVertexID(2);
    edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>()));
    edge.initialize();
    try {
      TezEvent ireEvent = new TezEvent(InputReadErrorEvent.create("diag", 0, 1),
          new EventMetaData(EventProducerConsumerType.INPUT, "v2", "v1",
              TezTaskAttemptID.getInstance(TezTaskID.getInstance(v2Id, 1), 1)));
      edge.sendTezEventToSourceTasks(ireEvent);
      Assert.fail();
    } catch (AMUserCodeException e) {
      e.printStackTrace();
      assertTrue(e.getCause().getMessage().contains("SourceTaskIndex should not be negative"));
    }
  }

  public static class CustomEdgeManagerWithInvalidReturnValue extends EdgeManagerPlugin {

    public static class EdgeManagerConfig implements Writable {
      int physicalInput = 1 ;
      int physicalOutput = 1;
      int consumerNumber = 1;
      int sourceTaskIndex = 1;

      public EdgeManagerConfig() {

      }

      public EdgeManagerConfig(int physicalInput, int physicalOutput,
          int consumerNumber, int sourceTaskIndex) {
        super();
        this.physicalInput = physicalInput;
        this.physicalOutput = physicalOutput;
        this.consumerNumber = consumerNumber;
        this.sourceTaskIndex = sourceTaskIndex;
      }

      public UserPayload toUserPayload() throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutput out = new DataOutputStream(bos);
        write(out);
        return UserPayload.create(ByteBuffer.wrap(bos.toByteArray()));
      }

      public static EdgeManagerConfig fromUserPayload(UserPayload payload)
          throws IOException {
        EdgeManagerConfig emConf = new EdgeManagerConfig();
        DataInputByteBuffer in  = new DataInputByteBuffer();
        in.reset(payload.getPayload());
        emConf.readFields(in);
        return emConf;
      }

      @Override
      public void write(DataOutput out) throws IOException {
        out.writeInt(physicalInput);
        out.writeInt(physicalOutput);
        out.writeInt(consumerNumber);
        out.writeInt(sourceTaskIndex);
      }

      @Override
      public void readFields(DataInput in) throws IOException {
        physicalInput = in.readInt();
        physicalOutput = in.readInt();
        consumerNumber = in.readInt();
        sourceTaskIndex = in.readInt();
      }
    }

    EdgeManagerConfig emConf;

    public CustomEdgeManagerWithInvalidReturnValue(
        EdgeManagerPluginContext context) {
      super(context);
    }

    @Override
    public void initialize() throws Exception {
      emConf = EdgeManagerConfig.fromUserPayload(getContext().getUserPayload());
    }

    @Override
    public int getNumDestinationTaskPhysicalInputs(int destinationTaskIndex)
        throws Exception {
      return emConf.physicalInput;
    }

    @Override
    public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex)
        throws Exception {
      return emConf.physicalOutput;
    }

    @Override
    public void routeDataMovementEventToDestination(DataMovementEvent event,
        int sourceTaskIndex, int sourceOutputIndex,
        Map<Integer, List<Integer>> destinationTaskAndInputIndices)
        throws Exception {
    }

    @Override
    public void routeInputSourceTaskFailedEventToDestination(
        int sourceTaskIndex,
        Map<Integer, List<Integer>> destinationTaskAndInputIndices)
        throws Exception {
    }

    @Override
    public int getNumDestinationConsumerTasks(int sourceTaskIndex)
        throws Exception {
      return emConf.consumerNumber;
    }

    @Override
    public int routeInputErrorEventToSource(InputReadErrorEvent event,
        int destinationTaskIndex, int destinationFailedInputIndex)
        throws Exception {
      return emConf.sourceTaskIndex;
    }
  }

  @Test(timeout = 5000)
  public void testEdgeManagerPluginCtxGetVertexGroupName() throws TezException {
    EdgeManagerPluginDescriptor edgeManagerDescriptor =
      EdgeManagerPluginDescriptor.create(EdgeManagerForTest.class.getName());
    EdgeProperty edgeProp = EdgeProperty.create(edgeManagerDescriptor,
      DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, OutputDescriptor.create("Out"),
      InputDescriptor.create("In"));
    Edge edge = new Edge(edgeProp, null, null);

    Vertex srcV = mock(Vertex.class), destV = mock(Vertex.class);
    String srcName = "srcV", destName = "destV";
    when(srcV.getName()).thenReturn(srcName);
    when(destV.getName()).thenReturn(destName);
    edge.setSourceVertex(srcV);
    edge.setDestinationVertex(destV);

    assertNull(edge.edgeManager.getContext().getVertexGroupName());

    String group = "group";
    when(destV.getGroupInputSpecList())
      .thenReturn(Arrays.asList(new GroupInputSpec(group, Arrays.asList("v1", "v3"), null)));
    assertNull(edge.edgeManager.getContext().getVertexGroupName());

    when(destV.getGroupInputSpecList())
      .thenReturn(Arrays.asList(new GroupInputSpec(group, Arrays.asList(srcName, "v3"), null)));
    assertEquals(group, edge.edgeManager.getContext().getVertexGroupName());
  }
}