/* * Copyright (C) 2020 Grakn Labs * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <https://www.gnu.org/licenses/>. * */ package grakn.core.kb.graql.planning; import com.google.common.collect.ImmutableMap; import grakn.core.common.util.Partition; import grakn.core.kb.graql.planning.spanningtree.graph.DirectedEdge; import grakn.core.kb.graql.planning.spanningtree.graph.Node; import grakn.core.kb.graql.planning.spanningtree.graph.WeightedGraph; import grakn.core.kb.graql.planning.spanningtree.util.Weighted; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import static com.google.common.base.Predicates.and; import static com.google.common.base.Predicates.not; import static grakn.core.kb.graql.planning.spanningtree.util.Weighted.weighted; /** * Chu-Liu-Edmonds' algorithm for finding a maximum branching in a complete, directed graph in O(n^2) time. * Implementation is based on Tarjan's "Finding Optimum Branchings" paper: * http://cw.felk.cvut.cz/lib/exe/fetch.php/courses/a4m33pal/cviceni/tarjan-finding-optimum-branchings.pdf * */ public class ChuLiuEdmonds { /** * Represents the subgraph that gets iteratively built up in the CLE algorithm. */ static class PartialSolution { // Partition representing the strongly connected components (SCCs). private final Partition<Node> stronglyConnected; // Partition representing the weakly connected components (WCCs). private final Partition<Node> weaklyConnected; // An invariant of the CLE algorithm is that each SCC always has at most one incoming edge. // You can think of these edges as implicitly defining a graph with SCCs as nodes. private final Map<Node, Weighted<DirectedEdge>> incomingEdgeByScc; // History of edges we've added, and for each, a list of edges it would exclude. // More recently added edges get priority over less recently added edges when reconstructing the final tree. private final Deque<ExclusiveEdge> edgesAndWhatTheyExclude; // a priority queue of incoming edges for each SCC that we haven't chosen an incoming edge for yet. final EdgeQueueMap unseenIncomingEdges; // running sum of weights. // edge weights are adjusted as we go to take into account the fact that we have an extra edge in each cycle private double score; private PartialSolution(Partition<Node> stronglyConnected, Partition<Node> weaklyConnected, Map<Node, Weighted<DirectedEdge>> incomingEdgeByScc, Deque<ExclusiveEdge> edgesAndWhatTheyExclude, EdgeQueueMap unseenIncomingEdges, double score) { this.stronglyConnected = stronglyConnected; this.weaklyConnected = weaklyConnected; this.incomingEdgeByScc = incomingEdgeByScc; this.edgesAndWhatTheyExclude = edgesAndWhatTheyExclude; this.unseenIncomingEdges = unseenIncomingEdges; this.score = score; } public static PartialSolution initialize(WeightedGraph graph) { final Partition<Node> stronglyConnected = Partition.singletons(graph.getNodes()); final HashMap<Node, Weighted<DirectedEdge>> incomingByScc = new HashMap<>(); final Deque<ExclusiveEdge> exclusiveEdges = new ArrayDeque<>(); // group edges by their destination component final EdgeQueueMap incomingEdges = new EdgeQueueMap(stronglyConnected); for (Node destinationNode : graph.getNodes()) { for (Weighted<DirectedEdge> inEdge : graph.getIncomingEdges(destinationNode)) { if (inEdge.weight != Double.NEGATIVE_INFINITY) { incomingEdges.addEdge(inEdge); } } } return new PartialSolution( stronglyConnected, Partition.singletons(graph.getNodes()), incomingByScc, exclusiveEdges, incomingEdges, 0.0 ); } public Set<Node> getNodes() { return stronglyConnected.getNodes(); } /** * Given an edge that completes a cycle, merge all SCCs on that cycle into one SCC. * Returns the new component. */ private Node merge(Weighted<DirectedEdge> newEdge, EdgeQueueMap unseenIncomingEdges) { // Find edges connecting SCCs on the path from newEdge.destination to newEdge.source final List<Weighted<DirectedEdge>> cycle = getCycle(newEdge); // build up list of queues that need to be merged, with the edge they would exclude final List<EdgeQueueMap.QueueAndReplace> queuesToMerge = new ArrayList<>(cycle.size()); for (Weighted<DirectedEdge> currentEdge : cycle) { final Node destination = stronglyConnected.componentOf(currentEdge.val.destination); final EdgeQueueMap.EdgeQueue queue = unseenIncomingEdges.queueByDestination.get(destination); // if we choose an edge in `queue`, we'll have to throw out `currentEdge` at the end // (each SCC can have only one incoming edge). queuesToMerge.add(EdgeQueueMap.QueueAndReplace.of(queue, currentEdge)); unseenIncomingEdges.queueByDestination.remove(destination); } // Merge all SCCs on the cycle into one for (Weighted<DirectedEdge> e : cycle) { stronglyConnected.merge(e.val.source, e.val.destination); } Node component = stronglyConnected.componentOf(newEdge.val.destination); // merge the queues and put the merged queue back into our map under the new component unseenIncomingEdges.merge(component, queuesToMerge); // keep our implicit graph of SCCs up to date: // we just created a cycle, so all in-edges have sources inside the new component // i.e. there is no edge with source outside component, and destination inside component incomingEdgeByScc.remove(component); return component; } /** * Gets the cycle of edges between SCCs that newEdge creates */ private List<Weighted<DirectedEdge>> getCycle(Weighted<DirectedEdge> newEdge) { final List<Weighted<DirectedEdge>> cycle = new ArrayList<>(); // circle around backward until you get back to where you started Weighted<DirectedEdge> edge = newEdge; cycle.add(edge); while (!stronglyConnected.sameComponent(edge.val.source, newEdge.val.destination)) { edge = incomingEdgeByScc.get(stronglyConnected.componentOf(edge.val.source)); cycle.add(edge); } return cycle; } /** * Adds the given edge to this subgraph, merging SCCs if necessary * * @return the new SCC if adding edge created a cycle */ public Optional<Node> addEdge(ExclusiveEdge wEdgeAndExcludes) { final DirectedEdge edge = wEdgeAndExcludes.edge; final double weight = wEdgeAndExcludes.weight; final Weighted<DirectedEdge> wEdge = weighted(edge, weight); score += weight; final Node destinationScc = stronglyConnected.componentOf(edge.destination); edgesAndWhatTheyExclude.addFirst(wEdgeAndExcludes); incomingEdgeByScc.put(destinationScc, wEdge); if (!weaklyConnected.sameComponent(edge.source, edge.destination)) { // Edge connects two different WCCs. Including it won't create a new cycle weaklyConnected.merge(edge.source, edge.destination); return Optional.empty(); } else { // Edge is contained within one WCC. Including it will create a new cycle. return Optional.of(merge(wEdge, unseenIncomingEdges)); } } /** * Recovers the optimal arborescence. * <p> * Each SCC can only have 1 edge entering it: the edge that we added most recently. * So we work backwards, adding edges unless they conflict with edges we've already added. * Runtime is O(n^2) in the worst case. */ private Weighted<Arborescence<Node>> recoverBestArborescence() { final ImmutableMap.Builder<Node, Node> parents = ImmutableMap.builder(); final Set<DirectedEdge> excluded = new HashSet<>(); // start with the most recent while (!edgesAndWhatTheyExclude.isEmpty()) { final ExclusiveEdge edgeAndWhatItExcludes = edgesAndWhatTheyExclude.pollFirst(); final DirectedEdge edge = edgeAndWhatItExcludes.edge; if (!excluded.contains(edge)) { excluded.addAll(edgeAndWhatItExcludes.excluded); parents.put(edge.destination, edge.source); } } return weighted(Arborescence.of(parents.build()), score); } public Optional<ExclusiveEdge> popBestEdge(Node component) { return popBestEdge(component, Arborescence.empty()); } /** * Always breaks ties in favor of edges in `best` */ public Optional<ExclusiveEdge> popBestEdge(Node component, Arborescence<Node> best) { return unseenIncomingEdges.popBestEdge(component, best); } } /** * Find an optimal arborescence of the given graph `graph`, rooted in the given node `root`. */ public static Weighted<Arborescence<Node>> getMaxArborescence(WeightedGraph graph, Node root) { // remove all edges incoming to `root`. resulting arborescence is then forced to be rooted at `root`. return getMaxArborescence(graph.filterEdges(not(DirectedEdge.hasDestination(root)))); } public static Weighted<Arborescence<Node>> getMaxArborescence(WeightedGraph graph, Set<DirectedEdge> required, Set<DirectedEdge> banned) { return getMaxArborescence(graph.filterEdges(and(not(DirectedEdge.competesWith(required)), not(DirectedEdge.isIn(banned))))); } /** * Find an optimal arborescence of the given graph. */ public static Weighted<Arborescence<Node>> getMaxArborescence(WeightedGraph graph) { final PartialSolution partialSolution = PartialSolution.initialize(graph.filterEdges(not(DirectedEdge.isAutoCycle()))); // In the beginning, subgraph has no edges, so no SCC has in-edges. final Deque<Node> componentsWithNoInEdges = new ArrayDeque<>(partialSolution.getNodes()); // Work our way through all componentsWithNoInEdges, in no particular order while (!componentsWithNoInEdges.isEmpty()) { final Node component = componentsWithNoInEdges.poll(); // find maximum edge entering 'component' from outside 'component'. final Optional<ExclusiveEdge> oMaxInEdge = partialSolution.popBestEdge(component); if (!oMaxInEdge.isPresent()) continue; // No in-edges left to consider for this component. Done with it! final ExclusiveEdge maxInEdge = oMaxInEdge.get(); // add the new edge to subgraph, merging SCCs if necessary final Optional<Node> newComponent = partialSolution.addEdge(maxInEdge); if (newComponent.isPresent()) { // addEdge created a cycle/component, which means the new component doesn't have any incoming edges componentsWithNoInEdges.add(newComponent.get()); } } // Once no component has incoming edges left to consider, it's time to recover the optimal branching. return partialSolution.recoverBestArborescence(); } }