package embedding;


import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.models.deepwalk.DeepWalk;
import org.deeplearning4j.graph.models.deepwalk.GraphHuffman;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.neo4j.collection.primitive.PrimitiveIntIterator;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.GraphFactory;
import org.neo4j.graphalgo.core.GraphLoader;
import org.neo4j.graphalgo.core.ProcedureConfiguration;
import org.neo4j.graphalgo.core.utils.Pools;
import org.neo4j.graphalgo.core.utils.ProgressTimer;
import org.neo4j.graphalgo.core.utils.TerminationFlag;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.write.Exporter;
import org.neo4j.graphalgo.results.PageRankScore;
import org.neo4j.graphdb.Direction;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.procedure.*;
import org.neo4j.values.storable.DeepWalkPropertyTranslator;
import org.neo4j.values.storable.LookupTablePropertyTranslator;

import java.util.*;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class DeepWalkProc {

    @Context
    public GraphDatabaseAPI api;

    @Context
    public Log log;

    @Context
    public KernelTransaction transaction;

    @Procedure(value = "embedding.deepWalk", mode = Mode.WRITE)
    @Description("CALL embedding.deepWalk(label:String, relationship:String, " +
            "{graph: 'heavy/cypher', vectorSize:10, windowSize:2, learningRate:0.01 concurrency:4, direction:'BOTH}) " +
            "YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, write, writeProperty" +
            " - calculates page rank and potentially writes back")
    public Stream<PageRankScore.Stats> deepWalk2(
            @Name(value = "label", defaultValue = "") String label,
            @Name(value = "relationship", defaultValue = "") String relationship,
            @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
        ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
        AllocationTracker tracker = AllocationTracker.create();

        PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder();
        final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration);

        int nodeCount = Math.toIntExact(graph.nodeCount());
        if (nodeCount == 0) {
            graph.release();
            return Stream.empty();
        }

        InMemoryGraphLookupTable lookupTable = runDeepWalk2(graph, statsBuilder, configuration);

        if (configuration.isWriteFlag()) {
            final String writeProperty = configuration.getWriteProperty("deepWalk");
            statsBuilder.timeWrite(() -> Exporter.of(api, graph)
                    .withLog(log)
                    .parallel(Pools.DEFAULT, configuration.getConcurrency(), TerminationFlag.wrap(transaction))
                    .build()
                    .write(
                            writeProperty,
                            lookupTable,
                            new LookupTablePropertyTranslator()
                    )
            );
        }

        return Stream.of(statsBuilder.build());
    }


    @Procedure(name = "embedding.deepWalk.stream", mode = Mode.READ)
    @Description("CALL embedding.deepWalk.stream(label:String, relationship:String, {graph: 'heavy/cypher', walkLength:10, vectorSize:10, windowSize:2, learningRate:0.01 concurrency:4, direction:'BOTH'}) " +
            "YIELD nodeId, embedding - compute embeddings for each node")
    public Stream<DeepWalkResult> deepWalkStream2(
            @Name(value = "label", defaultValue = "") String label,
            @Name(value = "relationship", defaultValue = "") String relationship,
            @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

        ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
        AllocationTracker tracker = AllocationTracker.create();

        PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder();
        final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration);

        int nodeCount = Math.toIntExact(graph.nodeCount());
        if (nodeCount == 0) {
            graph.release();
            return Stream.empty();
        }

        InMemoryGraphLookupTable lookupTable = runDeepWalk2(graph, statsBuilder, configuration);

        return IntStream.range(0, (int) graph.nodeCount()).mapToObj(index ->
                new DeepWalkResult(graph.toOriginalNodeId(index), lookupTable.getVector(index).toDoubleVector()));
    }

    private InMemoryGraphLookupTable runDeepWalk2(Graph graph, PageRankScore.Stats.Builder statsBuilder,
                                                  ProcedureConfiguration configuration) {
        long vectorSize = configuration.get("vectorSize", 10L);
        double learningRate = configuration.get("learningRate", 0.01);
        long windowSize = configuration.get("windowSize", 2L);
        long walkLength = configuration.get("walkSize", 10L);
        long numberOfWalks = configuration.get("numberOfWalks", 10L);

        Map<String, Number> params = new HashMap<>();
        params.put("vectorSize", vectorSize);
        params.put("learningRate", learningRate);
        params.put("windowSize", windowSize);
        params.put("walkLength", walkLength);
        params.put("numberOfWalks", numberOfWalks);

        log.info("Executing DeepWalk with params: %s", params);

        GraphHuffman gh = new GraphHuffman((int) graph.nodeCount());

        int[] degrees = new int[(int) graph.nodeCount()];
        graph.forEachNode(nodeId -> {
            degrees[nodeId] = graph.degree(nodeId, Direction.BOTH);
            return true;
        });

        gh.buildTree(degrees);
        InMemoryGraphLookupTable lookupTable = new InMemoryGraphLookupTable((int) graph.nodeCount(),
                (int) vectorSize, gh, learningRate);

        NodeWalker nodeWalker = new NodeWalker();
        NodeWalker.RandomNextNodeStrategy strategy = new NodeWalker.RandomNextNodeStrategy(graph, graph);

//        int limit = (((int) numberOfWalks) == -1) ? (int) graph.nodeCount() : Math.toIntExact(numberOfWalks);
        int limit = Math.toIntExact(graph.nodeCount());
        IntStream idStream = IntStream.range(0, Math.toIntExact(graph.nodeCount())).limit(limit);

        PrimitiveIterator.OfInt ints = IntStream.range(0, limit).unordered().parallel().flatMap((s) ->
                idStream).limit(limit).iterator();

        Stream<int[]> randomWalks = nodeWalker.internalRandomWalk((int) walkLength, strategy,
                TerminationFlag.wrap(transaction), 1, (int) numberOfWalks, ints);

        randomWalks.forEach(walk -> skipGram(walk, lookupTable, (int) windowSize));

        return lookupTable;
    }

    private void skipGram(int[] walk, InMemoryGraphLookupTable lookupTable, int windowSize) {
        for (int mid = windowSize; mid < walk.length - windowSize; mid++) {
            for (int pos = mid - windowSize; pos <= mid + windowSize; pos++) {
                if (pos == mid)
                    continue;

                lookupTable.iterate(walk[mid], walk[pos]);
            }
        }
    }


    @Procedure(value = "embedding.dl4j.deepWalk", mode = Mode.WRITE)
    @Description("CALL embedding.dl4j.deepWalk(label:String, relationship:String, " +
            "{graph: 'heavy/cypher', vectorSize:10, windowSize:2, learningRate:0.01 concurrency:4, direction:'BOTH}) " +
            "YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, write, writeProperty" +
            " - calculates page rank and potentially writes back")
    public Stream<PageRankScore.Stats> deepWalk(
            @Name(value = "label", defaultValue = "") String label,
            @Name(value = "relationship", defaultValue = "") String relationship,
            @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
        ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
        AllocationTracker tracker = AllocationTracker.create();

        PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder();
        final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration);

        int nodeCount = Math.toIntExact(graph.nodeCount());
        if (nodeCount == 0) {
            graph.release();
            return Stream.empty();
        }

        org.deeplearning4j.graph.graph.Graph<Integer, Integer> iGraph = buildDl4jGraph(graph);
        DeepWalk<Integer, Integer> dw = runDeepWalk(iGraph, statsBuilder, configuration);

        if (configuration.isWriteFlag()) {
            final String writeProperty = configuration.getWriteProperty("deepWalk");
            statsBuilder.timeWrite(() -> Exporter.of(api, graph)
                    .withLog(log)
                    .parallel(Pools.DEFAULT, configuration.getConcurrency(), TerminationFlag.wrap(transaction))
                    .build()
                    .write(
                            writeProperty,
                            dw,
                            new DeepWalkPropertyTranslator()
                    )
            );
        }

        return Stream.of(statsBuilder.build());
    }


    @Procedure(name = "embedding.dl4j.deepWalk.stream", mode = Mode.READ)
    @Description("CALL embedding.dl4j.deepWalk.stream(label:String, relationship:String, {graph: 'heavy/cypher', walkLength:10, vectorSize:10, windowSize:2, learningRate:0.01 concurrency:4, direction:'BOTH'}) " +
            "YIELD nodeId, embedding - compute embeddings for each node")
    public Stream<DeepWalkResult> deepWalkStream(
            @Name(value = "label", defaultValue = "") String label,
            @Name(value = "relationship", defaultValue = "") String relationship,
            @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

        ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
        AllocationTracker tracker = AllocationTracker.create();

        PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder();
        final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration);

        int nodeCount = Math.toIntExact(graph.nodeCount());
        if (nodeCount == 0) {
            graph.release();
            return Stream.empty();
        }

        org.deeplearning4j.graph.graph.Graph<Integer, Integer> iGraph = buildDl4jGraph(graph);
        DeepWalk<Integer, Integer> dw = runDeepWalk(iGraph, statsBuilder, configuration);

        return IntStream.range(0, dw.numVertices()).mapToObj(index ->
                new DeepWalkResult(graph.toOriginalNodeId(index), dw.getVertexVector(index).toDoubleVector()));
    }

    private org.deeplearning4j.graph.graph.Graph<Integer, Integer> buildDl4jGraph(Graph graph) {
        List<Vertex<Integer>> nodes = new ArrayList<>();

        PrimitiveIntIterator nodeIterator = graph.nodeIterator();
        while(nodeIterator.hasNext()) {
            int nodeId = nodeIterator.next();
            nodes.add(new Vertex<>(nodeId,nodeId));
        }

        org.deeplearning4j.graph.graph.Graph<Integer, Integer> iGraph = new org.deeplearning4j.graph.graph.Graph<>(nodes);

        nodeIterator = graph.nodeIterator();
        while(nodeIterator.hasNext()) {
            int nodeId = nodeIterator.next();
            graph.forEachRelationship(nodeId, Direction.BOTH, (sourceNodeId, targetNodeId, relationId) -> {
                iGraph.addEdge(nodeId, targetNodeId, -1, false);
                return false;
            });
        }
        return iGraph;
    }

    private DeepWalk<Integer, Integer> runDeepWalk(org.deeplearning4j.graph.graph.Graph<Integer, Integer> iGraph,
                                                   PageRankScore.Stats.Builder statsBuilder, ProcedureConfiguration configuration) {
        long vectorSize = configuration.get("vectorSize", 10L);
        double learningRate = configuration.get("learningRate", 0.01);
        long  windowSize = configuration.get("windowSize", 2L);
        long walkLength = configuration.get("walkSize", 10L);
        long numberOfWalks = configuration.get("numberOfWalks", 10L);

        Map<String, Number> params = new HashMap<>();
        params.put("vectorSize", vectorSize);
        params.put("learningRate", learningRate);
        params.put("windowSize", windowSize);
        params.put("walkLength", walkLength);
        params.put("numberOfWalks", numberOfWalks);

        log.info("Executing DeepWalk with params: %s", params);


        DeepWalk.Builder<Integer, Integer> builder = new DeepWalk.Builder<>();
        builder.vectorSize((int) vectorSize);
        builder.learningRate(learningRate);
        builder.windowSize((int) windowSize);
        DeepWalk<Integer, Integer> dw = builder.build();

        dw.initialize(iGraph);

        statsBuilder.timeEval(() -> dw.fit(new MyRandomWalkGraphIteratorProvider<>(
                iGraph, (int) walkLength, 1,
                NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED, (int) numberOfWalks)));

        return dw;
    }


    private Graph load(
            String label,
            String relationship,
            AllocationTracker tracker,
            Class<? extends GraphFactory> graphFactory,
            PageRankScore.Stats.Builder statsBuilder, ProcedureConfiguration configuration) {
        GraphLoader graphLoader = new GraphLoader(api, Pools.DEFAULT)
                .init(log, label, relationship, configuration)
                .withAllocationTracker(tracker)
                .withDirection(configuration.getDirection(Direction.BOTH))
                .withoutNodeProperties()
                .withoutNodeWeights()
                .withoutRelationshipWeights();


        try (ProgressTimer timer = ProgressTimer.start()) {
            Graph graph = graphLoader.load(graphFactory);
            statsBuilder.withNodes(graph.nodeCount());
            return graph;
        }
    }

}