package com.maxdemarzi.processing.unionfind;

import com.maxdemarzi.processing.NodeCounter;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2LongMap;
import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap;
import org.neo4j.graphdb.*;

/*
    Weighted quick-union with path compression
    See https://www.cs.princeton.edu/~rs/AlgsDS07/01UnionFind.pdf
 */

public class UnionFindMapStorage implements UnionFind {
    private final GraphDatabaseService db;
    private final long nodes;
    private Long2IntOpenHashMap rankMap;
    private Long2LongMap rootMap;

    public UnionFindMapStorage(GraphDatabaseService db) {
        this.db = db;
        this.rootMap = new Long2LongOpenHashMap();
        this.rankMap = new Long2IntOpenHashMap();
        this.nodes = new NodeCounter().getNodeCount(db);
    }

    @Override
    public void compute(String label, String type, int iterations) {
        RelationshipType relationshipType = RelationshipType.withName(type);

        try ( Transaction tx = db.beginTx()) {
            ResourceIterator<Node> nodes = db.findNodes(DynamicLabel.label(label));
            while (nodes.hasNext()) {
                long nodeId = nodes.next().getId();
                rootMap.put(nodeId, nodeId);
                rankMap.put(nodeId, 1);
            }

            for( Relationship relationship : db.getAllRelationships()) {
                if (relationship.isType(relationshipType)) {

                    long x = relationship.getStartNode().getId();
                    long y = relationship.getEndNode().getId();
                    if (x == y) { continue; }

                    x = rootMap.get(relationship.getStartNode().getId());
                    y = rootMap.get(relationship.getEndNode().getId());

                    while (x != rootMap.get(x)) {
                        rootMap.put(x, rootMap.get(rootMap.get(x)));
                        x = rootMap.get(x);
                    }
                    while (y != rootMap.get(y)) {
                        rootMap.put(y, rootMap.get(rootMap.get(y)));
                        y = rootMap.get(y);
                    }

                    if (x != y) {
                        if ( rankMap.get(x) > rankMap.get(y)) {
                            rootMap.put(y, x);
                        } else if (rankMap.get(x) < rankMap.get(y)) {
                            rootMap.put(x, y);
                        } else {
                            rootMap.put(y, x);
                            rankMap.put(x, rankMap.get(x) + 1);
                        }
                    }
                }
            }

            // This part is technically not necessary since we can just follow the unionfind property
            // of any node UP the tree to see if it's really connected or not.

            int iteration = 0;
            boolean done = false;

            while (!done && iterations > 0) {
                done = true;
                iteration++;
                nodes = db.findNodes(DynamicLabel.label(label));
                while (nodes.hasNext()) {
                    long x = nodes.next().getId();
                    if (rootMap.get(x) != x) {
                        done = false;
                        // This can be changed to be the GrandParent instead of Parent for faster convergence.
                        rootMap.put(x, rootMap.get(rootMap.get(x)));
                    }
                }

                if (iteration > iterations) {
                    done = true;
                }
            }
        }
    }

    @Override
    public double getResult(long node) {
        return rootMap != null ? rootMap.getOrDefault(node, -1L) : -1;
    }

    @Override
    public long numberOfNodes() {
        return nodes;
    };
}