/*
 * Copyright (C) 2017 Seoul National University
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.snu.nemo.common.ir.vertex;

import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.dag.DAGBuilder;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
import edu.snu.nemo.common.ir.executionproperty.ExecutionProperty;

import java.io.Serializable;
import java.util.*;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;

/**
 * IRVertex that contains a partial DAG that is iterative.
 *
 * <p>The contained DAG is accumulated in a {@link DAGBuilder}. Edges that cross the loop boundary are
 * tracked separately: edges entering the DAG for the upcoming iteration, edges coming from the previous
 * iteration, edges coming from outside the iterations, and edges leaving the DAG after the final iteration.
 * The vertex starts with a maximum of one iteration and a termination condition that is never satisfied.
 */
public final class LoopVertex extends IRVertex {
  // Counter shared by all LoopVertex instances; supplies the group ids handed out by markDuplicateEdges().
  // It is a plain static int that is incremented without any synchronization.
  private static int duplicateEdgeGroupId = 0;
  private final DAGBuilder<IRVertex, IREdge> builder; // Contains DAG information
  private final String compositeTransformFullName;

  private final Map<IRVertex, Set<IREdge>> dagIncomingEdges; // for the initial iteration
  private final Map<IRVertex, Set<IREdge>> iterativeIncomingEdges; // Edges from previous iterations connected internal.
  private final Map<IRVertex, Set<IREdge>> nonIterativeIncomingEdges; // Edges from outside previous iterations.
  private final Map<IRVertex, Set<IREdge>> dagOutgoingEdges; // for the final iteration
  private final Map<IREdge, IREdge> edgeWithLoopToEdgeWithInternalVertex;
  private final Map<IREdge, IREdge> edgeWithInternalVertexToEdgeWithLoop;

  private Integer maxNumberOfIterations;
  private IntPredicate terminationCondition;

  /**
   * The LoopVertex constructor.
   * @param compositeTransformFullName full name of the composite transform.
   */
  public LoopVertex(final String compositeTransformFullName) {
    super();
    this.builder = new DAGBuilder<>();
    this.compositeTransformFullName = compositeTransformFullName;
    this.dagIncomingEdges = new HashMap<>();
    this.iterativeIncomingEdges = new HashMap<>();
    this.nonIterativeIncomingEdges = new HashMap<>();
    this.dagOutgoingEdges = new HashMap<>();
    this.edgeWithLoopToEdgeWithInternalVertex = new HashMap<>();
    this.edgeWithInternalVertexToEdgeWithLoop = new HashMap<>();
    this.maxNumberOfIterations = 1; // 1 is the default number of iterations.
    this.terminationCondition = (IntPredicate & Serializable) (integer -> false); // nothing much yet.
  }

  /**
   * Creates a new LoopVertex carrying the same name, contained DAG, crossing edges, edge mappings,
   * maximum number of iterations, termination condition and execution properties as this one.
   * The vertices and edges themselves are reused as-is (they are not cloned).
   * @return the new LoopVertex.
   */
  @Override
  public LoopVertex getClone() {
    final LoopVertex newLoopVertex = new LoopVertex(compositeTransformFullName);
    final DAGBuilder<IRVertex, IREdge> newBuilder = newLoopVertex.getBuilder();

    // Copy all elements to the clone
    final DAG<IRVertex, IREdge> dagToCopy = this.getDAG();
    dagToCopy.topologicalDo(v -> {
      newBuilder.addVertex(v, dagToCopy);
      dagToCopy.getIncomingEdgesOf(v).forEach(newBuilder::connectVertices);
    });
    this.dagIncomingEdges.forEach((v, es) -> es.forEach(newLoopVertex::addDagIncomingEdge));
    this.iterativeIncomingEdges.forEach((v, es) -> es.forEach(newLoopVertex::addIterativeIncomingEdge));
    this.nonIterativeIncomingEdges.forEach((v, es) -> es.forEach(newLoopVertex::addNonIterativeIncomingEdge));
    this.dagOutgoingEdges.forEach((v, es) -> es.forEach(newLoopVertex::addDagOutgoingEdge));
    // mapEdgeWithLoop fills both directions of the mapping, so copying one direction is sufficient.
    this.edgeWithLoopToEdgeWithInternalVertex.forEach(newLoopVertex::mapEdgeWithLoop);
    newLoopVertex.setMaxNumberOfIterations(maxNumberOfIterations);
    newLoopVertex.setTerminationCondition(terminationCondition);

    this.copyExecutionPropertiesTo(newLoopVertex);
    return newLoopVertex;
  }

  /**
   * @return DAGBuilder of the LoopVertex.
   */
  public DAGBuilder<IRVertex, IREdge> getBuilder() {
    return builder;
  }

  /**
   * Builds the contained DAG from the builder, without the source/sink check.
   * A build is triggered on every call.
   * @return the DAG of the LoopVertex.
   */
  public DAG<IRVertex, IREdge> getDAG() {
    return builder.buildWithoutSourceSinkCheck();
  }

  /**
   * @return the full name of the composite transform.
   */
  public String getName() {
    return compositeTransformFullName;
  }

  /**
   * Maps an edge from/to loop with the corresponding edge from/to internal vertex.
   * The mapping is recorded in both directions.
   * @param edgeWithLoop an edge from/to loop
   * @param edgeWithInternalVertex the corresponding edge from/to internal vertex
   */
  public void mapEdgeWithLoop(final IREdge edgeWithLoop, final IREdge edgeWithInternalVertex) {
    this.edgeWithLoopToEdgeWithInternalVertex.put(edgeWithLoop, edgeWithInternalVertex);
    this.edgeWithInternalVertexToEdgeWithLoop.put(edgeWithInternalVertex, edgeWithLoop);
  }

  /**
   * @param edgeWithInternalVertex an edge with internal vertex
   * @return the corresponding edge with loop for the specified edge with internal vertex,
   *         or {@code null} if no mapping has been registered for it.
   */
  public IREdge getEdgeWithLoop(final IREdge edgeWithInternalVertex) {
    return this.edgeWithInternalVertexToEdgeWithLoop.get(edgeWithInternalVertex);
  }

  /**
   * Adds the incoming edge of the contained DAG, grouped by its destination vertex.
   * @param edge edge to add.
   */
  public void addDagIncomingEdge(final IREdge edge) {
    addEdgeTo(this.dagIncomingEdges, edge.getDst(), edge);
  }

  /**
   * @return incoming edges of the contained DAG, keyed by destination vertex (the internal, mutable map).
   */
  public Map<IRVertex, Set<IREdge>> getDagIncomingEdges() {
    return this.dagIncomingEdges;
  }

  /**
   * Adds an iterative incoming edge, from the previous iteration, but connected internally.
   * The edge is grouped by its destination vertex.
   * @param edge edge to add.
   */
  public void addIterativeIncomingEdge(final IREdge edge) {
    addEdgeTo(this.iterativeIncomingEdges, edge.getDst(), edge);
  }

  /**
   * @return the iterative incoming edges inside the DAG, keyed by destination vertex
   *         (the internal, mutable map).
   */
  public Map<IRVertex, Set<IREdge>> getIterativeIncomingEdges() {
    return this.iterativeIncomingEdges;
  }

  /**
   * Adds a non-iterative incoming edge, from outside the previous iteration.
   * The edge is grouped by its destination vertex.
   * @param edge edge to add.
   */
  public void addNonIterativeIncomingEdge(final IREdge edge) {
    addEdgeTo(this.nonIterativeIncomingEdges, edge.getDst(), edge);
  }

  /**
   * @return the non-iterative incoming edges of the LoopVertex, keyed by destination vertex
   *         (the internal, mutable map).
   */
  public Map<IRVertex, Set<IREdge>> getNonIterativeIncomingEdges() {
    return this.nonIterativeIncomingEdges;
  }

  /**
   * Adds an outgoing edge of the contained DAG, grouped by its source vertex.
   * @param edge edge to add.
   */
  public void addDagOutgoingEdge(final IREdge edge) {
    addEdgeTo(this.dagOutgoingEdges, edge.getSrc(), edge);
  }

  /**
   * @return outgoing edges of the contained DAG, keyed by source vertex (the internal, mutable map).
   */
  public Map<IRVertex, Set<IREdge>> getDagOutgoingEdges() {
    return this.dagOutgoingEdges;
  }

  /**
   * Adds the edge to the set registered under the given vertex, creating the set on first use.
   * @param edgesByVertex the map to add the edge to.
   * @param vertex the vertex under which the edge is grouped.
   * @param edge the edge to add.
   */
  private static void addEdgeTo(final Map<IRVertex, Set<IREdge>> edgesByVertex,
                                final IRVertex vertex,
                                final IREdge edge) {
    // computeIfAbsent only allocates a set when the vertex has none yet, and needs a single lookup.
    edgesByVertex.computeIfAbsent(vertex, key -> new HashSet<>()).add(edge);
  }

  /**
   * Marks duplicate edges with DuplicateEdgeGroupProperty.
   * Every non-iterative incoming edge receives its own group id, taken from the class-wide counter.
   */
  public void markDuplicateEdges() {
    nonIterativeIncomingEdges.values().forEach(irEdges -> irEdges.forEach(irEdge -> {
      irEdge.setProperty(
          DuplicateEdgeGroupProperty.of(new DuplicateEdgeGroupPropertyValue(String.valueOf(duplicateEdgeGroupId))));
      duplicateEdgeGroupId++;
    }));
  }

  /**
   * Method for unrolling an iteration of the LoopVertex.
   *
   * <p>Decreases the maximum number of iterations by one, adds clones of the contained vertices and
   * copies of the internal and DAG incoming edges to the given builder, adds the DAG outgoing edges if
   * the termination condition is then met, and finally re-computes the DAG incoming edges for the next
   * iteration from the non-iterative and iterative incoming edges.
   *
   * @param dagBuilder DAGBuilder to add the unrolled iteration to.
   * @return this LoopVertex, whose maximum number of iterations has been decreased by one.
   */
  public LoopVertex unRollIteration(final DAGBuilder<IRVertex, IREdge> dagBuilder) {
    final Map<IRVertex, IRVertex> originalToNewIRVertex = new HashMap<>();
    final DAG<IRVertex, IREdge> dagToAdd = getDAG();

    decreaseMaxNumberOfIterations();

    // add the DAG and internal edges to the dagBuilder.
    // Topological order guarantees that the source of every internal edge has been cloned already.
    dagToAdd.topologicalDo(irVertex -> {
      final IRVertex newIrVertex = irVertex.getClone();
      originalToNewIRVertex.putIfAbsent(irVertex, newIrVertex);

      dagBuilder.addVertex(newIrVertex, dagToAdd);
      dagToAdd.getIncomingEdgesOf(irVertex).forEach(edge ->
          dagBuilder.connectVertices(
              copyEdgeWithEndpoints(edge, originalToNewIRVertex.get(edge.getSrc()), newIrVertex)));
    });

    // process DAG incoming edges.
    getDagIncomingEdges().forEach((dstVertex, irEdges) -> irEdges.forEach(edge ->
        dagBuilder.connectVertices(
            copyEdgeWithEndpoints(edge, edge.getSrc(), originalToNewIRVertex.get(dstVertex)))));

    if (loopTerminationConditionMet()) {
      // if termination condition met, we process the DAG outgoing edge.
      getDagOutgoingEdges().forEach((srcVertex, irEdges) -> irEdges.forEach(edge -> {
        final IREdge newIrEdge = copyEdgeWithEndpoints(edge, originalToNewIRVertex.get(srcVertex), edge.getDst());
        dagBuilder.addVertex(edge.getDst()).connectVertices(newIrEdge);
      }));
    }

    // process next iteration's DAG incoming edges
    this.getDagIncomingEdges().clear();
    this.nonIterativeIncomingEdges.forEach((dstVertex, irEdges) -> irEdges.forEach(this::addDagIncomingEdge));
    // Iterative edges now originate from the vertices cloned for the iteration that was just unrolled.
    this.iterativeIncomingEdges.forEach((dstVertex, irEdges) -> irEdges.forEach(edge ->
        this.addDagIncomingEdge(
            copyEdgeWithEndpoints(edge, originalToNewIRVertex.get(edge.getSrc()), dstVertex))));

    return this;
  }

  /**
   * Creates a new edge between the given endpoints that has the communication pattern, coder,
   * side-input flag and execution properties of the given edge.
   * @param edge the edge to copy from.
   * @param src source vertex of the new edge.
   * @param dst destination vertex of the new edge.
   * @return the new edge.
   */
  private static IREdge copyEdgeWithEndpoints(final IREdge edge, final IRVertex src, final IRVertex dst) {
    final IREdge newIrEdge = new IREdge(edge.getProperty(ExecutionProperty.Key.DataCommunicationPattern),
        src, dst, edge.getCoder(), edge.isSideInput());
    edge.copyExecutionPropertiesTo(newIrEdge);
    return newIrEdge;
  }

  /**
   * Checks the termination condition, using the current maximum number of iterations as its input.
   * @return whether or not the loop termination condition has been met.
   */
  public Boolean loopTerminationConditionMet() {
    return loopTerminationConditionMet(maxNumberOfIterations);
  }

  /**
   * The condition is met when no iterations remain (maximum number of iterations is zero or less),
   * or when the termination predicate accepts the given input.
   * @param intPredicateInput input for the intPredicate of the loop termination condition.
   * @return whether or not the loop termination condition has been met.
   */
  public Boolean loopTerminationConditionMet(final Integer intPredicateInput) {
    return maxNumberOfIterations <= 0 || terminationCondition.test(intPredicateInput);
  }

  /**
   * Set the maximum number of iterations.
   * @param maxNum maximum number of iterations.
   */
  public void setMaxNumberOfIterations(final Integer maxNum) {
    this.maxNumberOfIterations = maxNum;
  }

  /**
   * @return termination condition int predicate.
   */
  public IntPredicate getTerminationCondition() {
    return terminationCondition;
  }

  /**
   * @return maximum number of iterations.
   */
  public Integer getMaxNumberOfIterations() {
    return this.maxNumberOfIterations;
  }

  /**
   * increase the value of maximum number of iterations by 1.
   */
  public void increaseMaxNumberOfIterations() {
    this.maxNumberOfIterations++;
  }

  /**
   * decrease the value of maximum number of iterations by 1.
   */
  private void decreaseMaxNumberOfIterations() {
    this.maxNumberOfIterations--;
  }

  /**
   * Set the intPredicate termination condition for the LoopVertex.
   * @param terminationCondition the termination condition to set.
   */
  public void setTerminationCondition(final IntPredicate terminationCondition) {
    this.terminationCondition = terminationCondition;
  }

  /**
   * Renders the vertex properties, the remaining number of iterations, the contained DAG, the DAG
   * incoming/outgoing edges and the loop-edge to internal-edge id mapping as a JSON object string.
   * @return the JSON string.
   */
  @Override
  public String propertiesToJSON() {
    final String edgeMappings = edgeWithLoopToEdgeWithInternalVertex.entrySet().stream()
        .map(entry -> String.format("\"%s\": \"%s\"", entry.getKey().getId(), entry.getValue().getId()))
        .collect(Collectors.joining(", "));
    final StringBuilder sb = new StringBuilder();
    sb.append("{")
        .append(irVertexPropertiesToString())
        .append(", \"remainingIteration\": ")
        .append(this.maxNumberOfIterations)
        .append(", \"DAG\": ")
        .append(getDAG())
        .append(", \"dagIncomingEdges\": ").append(crossingEdgesToJSON(dagIncomingEdges))
        .append(", \"dagOutgoingEdges\": ").append(crossingEdgesToJSON(dagOutgoingEdges))
        .append(", \"edgeWithLoopToEdgeWithInternalVertex\": {").append(edgeMappings)
        .append("}}");
    return sb.toString();
  }

  /**
   * Convert the crossing edges to JSON.
   * The result is a JSON object that maps each vertex id to an array of the ids of its edges.
   * @param map map of the crossing edges.
   * @return a string of JSON showing the crossing edges.
   */
  private static String crossingEdgesToJSON(final Map<IRVertex, Set<IREdge>> map) {
    return map.entrySet().stream()
        .map(entry -> {
          final String edgeIds = entry.getValue().stream()
              .map(e -> "\"" + e.getId() + "\"")
              .collect(Collectors.joining(", "));
          return "\"" + entry.getKey().getId() + "\": [" + edgeIds + "]";
        })
        .collect(Collectors.joining(", ", "{", "}"));
  }
}