package grakn.core.kb.graql.planning;

import grakn.core.common.util.Partition;
import grakn.core.kb.graql.planning.spanningtree.graph.DirectedEdge;
import grakn.core.kb.graql.planning.spanningtree.graph.Node;
import grakn.core.kb.graql.planning.spanningtree.graph.WeightedGraph;
import grakn.core.kb.graql.planning.spanningtree.util.Weighted;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static grakn.core.kb.graql.planning.spanningtree.util.Weighted.weighted;

/**
* Chu-Liu-Edmonds' algorithm for finding a maximum branching in a complete, directed graph in O(n^2) time.
* Implementation is based on Tarjan's "Finding Optimum Branchings" paper:
* http://cw.felk.cvut.cz/lib/exe/fetch.php/courses/a4m33pal/cviceni/tarjan-finding-optimum-branchings.pdf
*
*/
public class ChuLiuEdmonds {
/**
* Represents the subgraph that gets iteratively built up in the CLE algorithm.
*/
static class PartialSolution {
// Partition of the nodes into strongly connected components (SCCs).
private final Partition<Node> stronglyConnected;
// Partition of the nodes into weakly connected components (WCCs).
private final Partition<Node> weaklyConnected;
// An invariant of the CLE algorithm is that each SCC always has at most one incoming edge.
// You can think of these edges as implicitly defining a graph with SCCs as nodes.
// Keys are the component nodes returned by stronglyConnected.componentOf(...).
private final Map<Node, Weighted<DirectedEdge>> incomingEdgeByScc;
// History of edges we've added, and for each, a list of edges it would exclude.
// More recently added edges get priority over less recently added edges when reconstructing the final tree.
private final Deque<ExclusiveEdge> edgesAndWhatTheyExclude;
// A priority queue of incoming edges for each SCC that we haven't chosen an incoming edge for yet.
final EdgeQueueMap unseenIncomingEdges;
// Running sum of the weights of the edges added so far.
// Edge weights are adjusted as we go to take into account the fact that we have an extra edge in each cycle.
private double score;

/**
 * Assembles a partial solution from already-built parts. Every argument is stored as given,
 * without copying.
 */
private PartialSolution(Partition<Node> stronglyConnected,
                        Partition<Node> weaklyConnected,
                        Map<Node, Weighted<DirectedEdge>> incomingEdgeByScc,
                        Deque<ExclusiveEdge> edgesAndWhatTheyExclude,
                        EdgeQueueMap unseenIncomingEdges,
                        double score) {
    // accumulated weight of the solution so far
    this.score = score;

    // connectivity bookkeeping
    this.weaklyConnected = weaklyConnected;
    this.stronglyConnected = stronglyConnected;

    // chosen edges, their history, and the edges still waiting to be considered
    this.unseenIncomingEdges = unseenIncomingEdges;
    this.edgesAndWhatTheyExclude = edgesAndWhatTheyExclude;
    this.incomingEdgeByScc = incomingEdgeByScc;
}

/**
 * Builds the starting state of the algorithm for {@code graph}: every node forms its own strongly
 * connected and its own weakly connected component, no edge has been chosen yet, and the score is 0.
 * Every incoming edge whose weight is not negative infinity is queued as a candidate.
 *
 * @param graph the weighted graph whose nodes and incoming edges seed the solution
 * @return a fresh partial solution with all candidate edges queued
 */
public static PartialSolution initialize(WeightedGraph graph) {
    final Partition<Node> stronglyConnected = Partition.singletons(graph.getNodes());
    final Map<Node, Weighted<DirectedEdge>> incomingByScc = new HashMap<>();
    final Deque<ExclusiveEdge> exclusiveEdges = new ArrayDeque<>();
    // group edges by their destination component
    final EdgeQueueMap incomingEdges = new EdgeQueueMap(stronglyConnected);
    for (Node destinationNode : graph.getNodes()) {
        for (Weighted<DirectedEdge> inEdge : graph.getIncomingEdges(destinationNode)) {
            // edges weighted negative infinity are left out of the candidate queues
            if (inEdge.weight != Double.NEGATIVE_INFINITY) {
                // previously this branch was empty, so no edge was ever queued
                incomingEdges.addEdge(inEdge);
            }
        }
    }
    return new PartialSolution(
            stronglyConnected,
            // the WCC partition needs its own instance; it evolves independently of the SCC partition
            Partition.singletons(graph.getNodes()),
            incomingByScc,
            exclusiveEdges,
            incomingEdges,
            0.0
    );
}

/**
 * @return the nodes reported by the strongly-connected-component partition
 */
public Set<Node> getNodes() {
    final Set<Node> nodes = stronglyConnected.getNodes();
    return nodes;
}

/**
 * Given an edge that completes a cycle, merge all SCCs on that cycle into one SCC.
 * The queues of unseen incoming edges of the merged SCCs are combined into a single queue
 * that is stored under the new component.
 *
 * @param newEdge             the edge that closes the cycle
 * @param unseenIncomingEdges the per-component queues of edges not yet considered
 * @return the new component
 */
private Node merge(Weighted<DirectedEdge> newEdge, EdgeQueueMap unseenIncomingEdges) {
    // Find edges connecting SCCs on the path from newEdge.destination to newEdge.source
    final List<Weighted<DirectedEdge>> cycle = getCycle(newEdge);
    // build up list of queues that need to be merged, with the edge they would exclude
    final List<EdgeQueueMap.QueueAndReplace> queuesToMerge = new ArrayList<>(cycle.size());
    for (Weighted<DirectedEdge> currentEdge : cycle) {
        final Node destination = stronglyConnected.componentOf(currentEdge.val.destination);
        final EdgeQueueMap.EdgeQueue queue = unseenIncomingEdges.queueByDestination.get(destination);
        // if we choose an edge in `queue`, we'll have to throw out `currentEdge` at the end
        // (each SCC can have only one incoming edge).
        // Without this the queue would be dropped from the map below and its edges lost.
        queuesToMerge.add(EdgeQueueMap.QueueAndReplace.of(queue, currentEdge));
        unseenIncomingEdges.queueByDestination.remove(destination);
    }
    // Merge all SCCs on the cycle into one
    for (Weighted<DirectedEdge> e : cycle) {
        stronglyConnected.merge(e.val.source, e.val.destination);
    }
    final Node component = stronglyConnected.componentOf(newEdge.val.destination);
    // merge the queues and put the merged queue back into our map under the new component
    unseenIncomingEdges.merge(component, queuesToMerge);
    // keep our implicit graph of SCCs up to date:
    // we just created a cycle, so all in-edges have sources inside the new component
    // i.e. there is no edge with source outside component, and destination inside component
    incomingEdgeByScc.remove(component);
    return component;
}

/**
 * Gets the cycle of edges between SCCs that newEdge creates.
 * <p>
 * Starting from {@code newEdge}, follows the chosen incoming edge of each source component
 * backwards until an edge is reached whose source lies in the component of
 * {@code newEdge}'s destination.
 *
 * @param newEdge the edge that closes the cycle
 * @return the edges on the cycle, beginning with {@code newEdge}
 */
private List<Weighted<DirectedEdge>> getCycle(Weighted<DirectedEdge> newEdge) {
    final List<Weighted<DirectedEdge>> cycle = new ArrayList<>();
    // circle around backward until you get back to where you started
    Weighted<DirectedEdge> edge = newEdge;
    // the closing edge itself is part of the cycle
    cycle.add(edge);
    while (!stronglyConnected.sameComponent(edge.val.source, newEdge.val.destination)) {
        edge = incomingEdgeByScc.get(stronglyConnected.componentOf(edge.val.source));
        // collect every edge we step over; previously nothing was collected and the result was always empty
        cycle.add(edge);
    }
    return cycle;
}

/**
* Adds the given edge to this subgraph, merging SCCs if necessary
*
* @return the new SCC if adding edge created a cycle
*/
final DirectedEdge edge = wEdgeAndExcludes.edge;
final double weight = wEdgeAndExcludes.weight;
final Weighted<DirectedEdge> wEdge = weighted(edge, weight);
score += weight;
final Node destinationScc = stronglyConnected.componentOf(edge.destination);
incomingEdgeByScc.put(destinationScc, wEdge);
if (!weaklyConnected.sameComponent(edge.source, edge.destination)) {
// Edge connects two different WCCs. Including it won't create a new cycle
weaklyConnected.merge(edge.source, edge.destination);
return Optional.empty();
} else {
// Edge is contained within one WCC. Including it will create a new cycle.
return Optional.of(merge(wEdge, unseenIncomingEdges));
}
}

/**
 * Recovers the optimal arborescence.
 * <p>
 * Each SCC can only have 1 edge entering it: the edge that we added most recently.
 * The edge history is drained from the front; an edge is kept unless an edge kept before it
 * listed it as excluded.
 * Runtime is O(n^2) in the worst case.
 *
 * @return the arborescence built from the kept edges, weighted with the accumulated score
 */
private Weighted<Arborescence<Node>> recoverBestArborescence() {
    final ImmutableMap.Builder<Node, Node> parents = ImmutableMap.builder();
    final Set<DirectedEdge> excluded = new HashSet<>();
    // start from the front of the history
    while (!edgesAndWhatTheyExclude.isEmpty()) {
        final ExclusiveEdge edgeAndWhatItExcludes = edgesAndWhatTheyExclude.pollFirst();
        final DirectedEdge edge = edgeAndWhatItExcludes.edge;
        if (!excluded.contains(edge)) {
            // keeping this edge rules out the edges it excludes; previously the set was never
            // filled, so the check above could never reject anything
            excluded.addAll(edgeAndWhatItExcludes.excluded);
            parents.put(edge.destination, edge.source);
        }
    }
    return weighted(Arborescence.of(parents.build()), score);
}

/**
 * Same as {@link #popBestEdge(Node, Arborescence)}, called with an empty arborescence.
 */
public Optional<ExclusiveEdge> popBestEdge(Node component) {
    final Arborescence<Node> noPreference = Arborescence.empty();
    return popBestEdge(component, noPreference);
}

/**
 * Delegates to the queue map of unseen incoming edges.
 * Always breaks ties in favor of edges in `best`
 */
public Optional<ExclusiveEdge> popBestEdge(Node component, Arborescence<Node> best) {
    final Optional<ExclusiveEdge> bestEdge = unseenIncomingEdges.popBestEdge(component, best);
    return bestEdge;
}
}

/**
 * Find an optimal arborescence of the given graph `graph`, rooted in the given node `root`.
 * Works by filtering the graph first and then delegating to the single-argument overload.
 */
public static Weighted<Arborescence<Node>> getMaxArborescence(WeightedGraph graph, Node root) {
    // drop every edge that points into `root`; the resulting arborescence is then forced to be rooted at `root`
    final WeightedGraph withoutEdgesIntoRoot = graph.filterEdges(not(DirectedEdge.hasDestination(root)));
    return getMaxArborescence(withoutEdgesIntoRoot);
}

/**
 * Find an optimal arborescence of `graph` after filtering its edges: an edge is kept only if
 * it neither matches {@code DirectedEdge.competesWith(required)} nor is in `banned`.
 */
public static Weighted<Arborescence<Node>> getMaxArborescence(WeightedGraph graph,
                                                              Set<DirectedEdge> required,
                                                              Set<DirectedEdge> banned) {
    final WeightedGraph admissible = graph.filterEdges(
            and(not(DirectedEdge.competesWith(required)), not(DirectedEdge.isIn(banned))));
    return getMaxArborescence(admissible);
}

/**
* Find an optimal arborescence of the given graph.
*/
public static Weighted<Arborescence<Node>> getMaxArborescence(WeightedGraph graph) {
final PartialSolution partialSolution =
PartialSolution.initialize(graph.filterEdges(not(DirectedEdge.isAutoCycle())));
// In the beginning, subgraph has no edges, so no SCC has in-edges.
final Deque<Node> componentsWithNoInEdges = new ArrayDeque<>(partialSolution.getNodes());

// Work our way through all componentsWithNoInEdges, in no particular order
while (!componentsWithNoInEdges.isEmpty()) {
final Node component = componentsWithNoInEdges.poll();
// find maximum edge entering 'component' from outside 'component'.
final Optional<ExclusiveEdge> oMaxInEdge = partialSolution.popBestEdge(component);
if (!oMaxInEdge.isPresent()) continue; // No in-edges left to consider for this component. Done with it!
final ExclusiveEdge maxInEdge = oMaxInEdge.get();
// add the new edge to subgraph, merging SCCs if necessary