package macrobase.analysis.summary.itemset;

import static com.codahale.metrics.MetricRegistry.name;

import com.codahale.metrics.Timer;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import macrobase.MacroBase;
import macrobase.analysis.summary.itemset.result.ItemsetWithCount;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

public class StreamingFPGrowth {
    private static final Logger log = LoggerFactory.getLogger(StreamingFPGrowth.class);
    // Timers registered on the shared MacroBase metrics registry; each one wraps the
    // operation named in its metric name.
    private final Timer fpMine = MacroBase.metrics.timer(name(StreamingFPGrowth.class, "fpMine"));
    private final Timer restructureTree = MacroBase.metrics.timer(name(StreamingFPGrowth.class, "restructureTree"));
    private final Timer updateFrequentItemOrder = MacroBase.metrics.timer(
            name(StreamingFPGrowth.class, "updateFrequentItemOrder"));
    private final Timer insertFrequentItems = MacroBase.metrics.timer(
            name(StreamingFPGrowth.class, "insertFrequentItems"));

    // The tree that all public operations of this class delegate to.
    StreamingFPTree fp = new StreamingFPTree();
    // Set by the streaming insert methods; cleared by restructureTree. getCounts and
    // getItemsets restructure the tree first when this is true.
    boolean needsRestructure = false;
    // Checked by buildTree. NOTE(review): nothing in this file assigns true to it.
    boolean startedStreaming = false;


    // Fraction multiplied by a transaction count (buildTree) or by the root count
    // (getItemsets) to obtain the absolute count threshold.
    private final double support;

    /**
     * Creates a miner with an empty tree.
     *
     * @param support fraction used to derive the absolute count threshold in
     *                {@code buildTree} and {@code getItemsets}
     */
    public StreamingFPGrowth(double support) {
        this.support = support;
    }

    class StreamingFPTree {
        // Sentinel root: item -1, no parent, initial count 0.
        private FPTreeNode root = new FPTreeNode(-1, null, 0);
        // used to calculate the order
        // (item -> accumulated count; also acts as the filter for which items may enter the tree)
        private Map<Integer, Double> frequentItemCounts = new HashMap<>();

        // item order -- need canonical to break ties; 0 is smallest, N is largest
        private Map<Integer, Integer> frequentItemOrder = new HashMap<>();

        // item -> most recently created node for that item; nodes holding the same item
        // are chained through their nextLink/prevLink references
        protected Map<Integer, FPTreeNode> nodeHeaders = new HashMap<>();

        // nodes created for the last item of a transaction; a node leaves this set when it
        // gains a child or is unlinked via removeNodeFromHeaders
        protected Set<FPTreeNode> leafNodes = new HashSet<>();

        // nodes touched by non-streaming inserts; cleared by updateFrequentItemOrder and
        // consulted by sortByNewOrder to decide which leaves/ancestors can be left alone
        Set<FPTreeNode> sortedNodes = new HashSet<>();

        /**
         * Writes the item counts, the item ordering and finally every node of the
         * tree (starting at the root, depth 1) to the debug log.
         */
        private void printTreeDebug() {
            log.debug("Frequent Item Counts:");
            for (Map.Entry<Integer, Double> countEntry : frequentItemCounts.entrySet()) {
                log.debug("{} {}", countEntry.getKey(), countEntry.getValue());
            }

            log.debug("Frequent Item Order:");
            for (Map.Entry<Integer, Integer> orderEntry : frequentItemOrder.entrySet()) {
                log.debug("{} {}", orderEntry.getKey(), orderEntry.getValue());
            }

            walkTree(root, 1);
        }

        // todo: make more efficient
        /**
         * Multiplies the count of {@code start} and of every node below it by
         * {@code decayWeight}. When {@code start} is the tree root, the per-item
         * totals in {@code frequentItemCounts} are scaled by the same factor.
         *
         * @param start       node whose subtree is scaled
         * @param decayWeight multiplicative factor applied to each count
         */
        private void decayWeights(FPTreeNode start, double decayWeight) {
            if (start == root) {
                frequentItemCounts.replaceAll((item, weight) -> weight * decayWeight);
            }

            // Explicit stack instead of recursion; each node is scaled exactly once.
            Deque<FPTreeNode> pending = new ArrayDeque<>();
            pending.push(start);
            while (!pending.isEmpty()) {
                FPTreeNode node = pending.pop();
                node.count *= decayWeight;

                List<FPTreeNode> kids = node.getChildren();
                if (kids == null) {
                    continue;
                }
                for (FPTreeNode kid : kids) {
                    pending.push(kid);
                }
            }
        }


        /**
         * Logs {@code start} at debug level, prefixed with {@code treeDepth} tab
         * characters, then does the same for each child one level deeper.
         *
         * @param start     node to log
         * @param treeDepth number of tabs to indent this node's log line with
         */
        private void walkTree(FPTreeNode start, int treeDepth) {
            char[] tabs = new char[treeDepth];
            Arrays.fill(tabs, '\t');

            log.debug("{} node: {}, count: {}, sorted: {}",
                      new String(tabs),
                      start.getItem(), start.getCount(),
                      sortedNodes.contains(start));

            List<FPTreeNode> kids = start.getChildren();
            if (kids == null) {
                return;
            }
            for (FPTreeNode kid : kids) {
                walkTree(kid, treeDepth + 1);
            }
        }

        /**
         * One node of the prefix tree: an item, its accumulated count, its parent and
         * children, and the doubly linked chain ({@code prevLink}/{@code nextLink})
         * joining it to the other nodes that hold the same item.
         */
        private class FPTreeNode {
            private int item;
            private double count;
            private FPTreeNode nextLink;
            private FPTreeNode prevLink;
            private FPTreeNode parent;
            // stays null until the first child is attached
            private List<FPTreeNode> children;

            /**
             * @param item         item stored in this node
             * @param parent       parent node, or null for a root
             * @param initialCount starting count
             */
            public FPTreeNode(int item, FPTreeNode parent, double initialCount) {
                this.item = item;
                this.parent = parent;
                this.count = initialCount;
            }

            /** @return the item stored in this node */
            public int getItem() {
                return item;
            }

            /** @return the accumulated count of this node */
            public double getCount() {
                return count;
            }

            /** Adds {@code by} to this node's count. */
            public void incrementCount(double by) {
                count += by;
            }


            /** Subtracts {@code by} from this node's count. */
            public void decrementCount(double by) {
                count -= by;
            }

            /** @return true when the child list exists and is non-empty */
            public boolean hasChildren() {
                return children != null && !children.isEmpty();
            }

            /** Detaches {@code child} from this node's child list. */
            public void removeChild(FPTreeNode child) {
                assert (children.contains(child));
                children.remove(child);
            }

            public void setNextLink(FPTreeNode nextLink) {
                this.nextLink = nextLink;
            }

            /** @return the next node in this item's header chain, or null */
            public FPTreeNode getNextLink() {
                return nextLink;
            }

            public void setPrevLink(FPTreeNode prevLink) {
                this.prevLink = prevLink;
            }

            /** @return the previous node in this item's header chain, or null */
            public FPTreeNode getPrevLink() {
                return prevLink;
            }

            /** @return the parent node, or null for a root */
            public FPTreeNode getParent() {
                return parent;
            }


            /** @return the live child list, or null if no child was ever attached */
            public List<FPTreeNode> getChildren() {
                return children;
            }

            /**
             * Returns the first child holding {@code wantedItem}, or null if there is none.
             */
            private FPTreeNode childWithItem(int wantedItem) {
                if (children == null) {
                    return null;
                }
                for (FPTreeNode candidate : children) {
                    if (candidate.item == wantedItem) {
                        return candidate;
                    }
                }
                return null;
            }

            /**
             * Re-parents {@code otherChildren} under this node. A child whose item already
             * exists among this node's children is folded into the existing child (counts
             * added, grandchildren merged recursively) and unlinked from its header chain;
             * any other child is appended as-is.
             *
             * @param otherChildren nodes to adopt; null is a no-op
             */
            public void mergeChildren(List<FPTreeNode> otherChildren) {
                assert (!hasChildren() || !leafNodes.contains(this));

                if (otherChildren == null) {
                    return;
                }

                if (children == null) {
                    // Nothing to merge against: adopt a copy of the incoming list wholesale.
                    children = Lists.newArrayList(otherChildren);
                    for (FPTreeNode adopted : otherChildren) {
                        adopted.parent = this;
                    }
                    leafNodes.remove(this);

                    return;
                }

                // O(N^2); slow for large lists; consider optimizing
                for (FPTreeNode incoming : otherChildren) {
                    incoming.parent = this;

                    FPTreeNode twin = childWithItem(incoming.item);
                    if (twin == null) {
                        children.add(incoming);
                        continue;
                    }

                    removeNodeFromHeaders(incoming);

                    twin.count += incoming.count;
                    twin.mergeChildren(incoming.getChildren());
                }
            }

            /**
             * Creates a child for {@code childItem}, puts it at the head of that item's
             * header chain and appends it to this node's children.
             *
             * @param childItem       item of the new node
             * @param streaming       when false the new node is recorded in sortedNodes
             * @param lastInTransaction when true the new node is recorded in leafNodes
             * @return the new child
             */
            private FPTreeNode attachChild(int childItem, boolean streaming, boolean lastInTransaction) {
                FPTreeNode fresh = new FPTreeNode(childItem, this, 0);

                if (!streaming) {
                    sortedNodes.add(fresh);
                }

                FPTreeNode formerHead = nodeHeaders.put(childItem, fresh);
                if (formerHead != null) {
                    fresh.setNextLink(formerHead);
                    formerHead.setPrevLink(fresh);
                }

                if (children == null) {
                    children = new ArrayList<>();
                }
                children.add(fresh);

                if (lastInTransaction) {
                    leafNodes.add(fresh);
                }

                // this node now has a child, so it can no longer be a leaf
                leafNodes.remove(this);
                return fresh;
            }

            /**
             * Adds {@code itemCount} to this node, then walks down the tree along
             * {@code fullTransaction[currentIndex..]}, creating missing nodes and adding
             * {@code itemCount} to every node on the way.
             *
             * @param fullTransaction items in insertion order
             * @param currentIndex    index of the first item to place below this node
             * @param itemCount       amount added to every node on the path
             * @param streaming       when false, every visited node is recorded in sortedNodes
             */
            public void insertTransaction(List<Integer> fullTransaction,
                                          int currentIndex,
                                          final double itemCount,
                                          boolean streaming) {
                FPTreeNode cursor = this;
                int position = currentIndex;

                while (true) {
                    if (!streaming) {
                        sortedNodes.add(cursor);
                    }

                    cursor.incrementCount(itemCount);

                    if (position == fullTransaction.size()) {
                        return;
                    }

                    int wantedItem = fullTransaction.get(position);

                    FPTreeNode next = cursor.childWithItem(wantedItem);
                    if (next == null) {
                        boolean lastInTransaction = position == fullTransaction.size() - 1;
                        next = cursor.attachChild(wantedItem, streaming, lastInTransaction);
                    }

                    cursor = next;
                    position++;
                }
            }
        }

        /**
         * Sums the counts of all tree paths that contain every item of {@code pattern}.
         *
         * <p>The items are ordered by {@code frequentItemOrder}; the header chain of the
         * lowest-ordered item is followed and, from each node on that chain, the path to
         * the root is scanned for the remaining items.
         *
         * <p>Node counts are doubles (they can become fractional after decay), so the
         * total is accumulated as a double and truncated once on return rather than
         * truncating every individual node count.
         *
         * @param pattern items to look for; duplicates are ignored
         * @return the truncated total count, or 0 when the pattern is empty or contains
         *         an item that is not in {@code frequentItemCounts}
         */
        public int getSupport(Collection<Integer> pattern) {
            // De-duplicate: each item occurs at most once per path, so a repeated item
            // would otherwise make the "items still to find" counter unreachable.
            Set<Integer> wanted = new HashSet<>(pattern);
            if (wanted.isEmpty()) {
                return 0;
            }

            for (Integer item : wanted) {
                if (!frequentItemCounts.containsKey(item)) {
                    return 0;
                }
            }

            List<Integer> plist = new ArrayList<>(wanted);
            // traverse bottom to top
            plist.sort((i1, i2) -> frequentItemOrder.get(i1).compareTo(frequentItemOrder.get(i2)));

            double total = 0;
            for (FPTreeNode pathHead = nodeHeaders.get(plist.get(0));
                 pathHead != null;
                 pathHead = pathHead.getNextLink()) {
                int itemsToFind = wanted.size();

                for (FPTreeNode curNode = pathHead; curNode != null; curNode = curNode.getParent()) {
                    if (wanted.contains(curNode.getItem()) && --itemsToFind == 0) {
                        total += pathHead.getCount();
                        break;
                    }
                }
            }

            return (int) total;
        }

        /**
         * Counts how many transactions contain each item, records every item whose
         * count reaches {@code countRequiredForSupport} in {@code frequentItemCounts},
         * and assigns each recorded item its rank in {@code frequentItemOrder}
         * (0 = smallest count).
         *
         * <p>The timer context is stopped in a {@code finally} block so the timing is
         * recorded even if the body throws.
         *
         * @param transactions            transactions to count
         * @param countRequiredForSupport minimum count for an item to be recorded
         */
        public void insertFrequentItems(List<Set<Integer>> transactions,
                                        int countRequiredForSupport) {
            Timer.Context context = insertFrequentItems.time();
            try {
                Map<Integer, Double> itemCounts = new HashMap<>();
                for (Set<Integer> t : transactions) {
                    for (Integer item : t) {
                        itemCounts.merge(item, 1.0, Double::sum);
                    }
                }

                for (Map.Entry<Integer, Double> e : itemCounts.entrySet()) {
                    if (e.getValue() >= countRequiredForSupport) {
                        frequentItemCounts.put(e.getKey(), e.getValue());
                    }
                }

                // we have to materialize a canonical order so that items with equal counts
                // are consistently ordered when they are sorted during transaction insertion
                List<Map.Entry<Integer, Double>> sortedItemCounts =
                        Lists.newArrayList(frequentItemCounts.entrySet());
                sortedItemCounts.sort(Map.Entry.comparingByValue());
                for (int i = 0; i < sortedItemCounts.size(); ++i) {
                    frequentItemOrder.put(sortedItemCounts.get(i).getKey(), i);
                }
            } finally {
                context.stop();
            }
        }

        /**
         * Drops each given item from the count and order maps and splices every tree
         * node holding it out of the tree: the node is detached from its parent and
         * its children, if any, are merged into that parent.
         *
         * @param itemsToDelete items to drop; null is a no-op
         */
        private void deleteItems(Set<Integer> itemsToDelete) {
            if (itemsToDelete == null) {
                return;
            }

            for (int doomedItem : itemsToDelete) {
                frequentItemCounts.remove(doomedItem);
                frequentItemOrder.remove(doomedItem);

                // follow the header chain for this item
                for (FPTreeNode doomed = nodeHeaders.get(doomedItem);
                     doomed != null;
                     doomed = doomed.getNextLink()) {
                    FPTreeNode owner = doomed.parent;

                    owner.removeChild(doomed);
                    if (doomed.hasChildren()) {
                        owner.mergeChildren(doomed.children);
                    }

                    leafNodes.remove(doomed);
                }

                nodeHeaders.remove(doomedItem);
            }
        }

        /**
         * Forgets which nodes were sorted and rebuilds {@code frequentItemOrder} from
         * the current {@code frequentItemCounts}: items are ranked by ascending count,
         * rank 0 being the smallest.
         *
         * <p>The timer context is stopped in a {@code finally} block so the timing is
         * recorded even if the body throws.
         */
        private void updateFrequentItemOrder() {
            Timer.Context context = updateFrequentItemOrder.time();
            try {
                sortedNodes.clear();

                frequentItemOrder.clear();

                // we have to materialize a canonical order so that items with equal counts
                // are consistently ordered when they are sorted during transaction insertion
                List<Map.Entry<Integer, Double>> sortedItemCounts =
                        Lists.newArrayList(frequentItemCounts.entrySet());
                sortedItemCounts.sort(Map.Entry.comparingByValue());
                for (int i = 0; i < sortedItemCounts.size(); ++i) {
                    frequentItemOrder.put(sortedItemCounts.get(i).getKey(), i);
                }
            } finally {
                context.stop();
            }
        }

        /**
         * Totals, per item, the counts of all patterns containing that item, records
         * every item whose total reaches {@code countRequiredForSupport} in
         * {@code frequentItemCounts}, and then rebuilds the item order.
         *
         * @param patterns                weighted itemsets to count
         * @param countRequiredForSupport minimum total for an item to be recorded
         */
        public void insertConditionalFrequentItems(List<ItemsetWithCount> patterns,
                                                   int countRequiredForSupport) {
            Map<Integer, Double> weightedTotals = new HashMap<>();

            for (ItemsetWithCount pattern : patterns) {
                final double weight = pattern.getCount();
                for (Integer item : pattern.getItems()) {
                    weightedTotals.merge(item, weight, Double::sum);
                }
            }

            weightedTotals.forEach((item, total) -> {
                if (total >= countRequiredForSupport) {
                    frequentItemCounts.put(item, total);
                }
            });

            updateFrequentItemOrder();
        }

        /**
         * Sorts {@code txn} in place by descending {@code frequentItemOrder} rank.
         * In streaming mode an item without a rank is given the rank {@code -item}
         * the first time the comparator sees it, and that rank is stored.
         *
         * @param txn         items to sort in place
         * @param isStreaming whether missing ranks should be created on the fly
         */
        private void sortTransaction(List<Integer> txn, boolean isStreaming) {
            if (isStreaming) {
                txn.sort((left, right) -> {
                    Integer rightRank = frequentItemOrder.computeIfAbsent(right, k -> -k);
                    Integer leftRank = frequentItemOrder.computeIfAbsent(left, k -> -k);
                    return rightRank.compareTo(leftRank);
                });
                return;
            }

            txn.sort((left, right) -> frequentItemOrder.get(right).compareTo(frequentItemOrder.get(left)));
        }

        /**
         * Inserts every pattern below the root, weighted by the pattern's count.
         *
         * @param patterns weighted itemsets to insert
         */
        public void insertConditionalFrequentPatterns(List<ItemsetWithCount> patterns) {
            patterns.forEach(pattern -> reinsertBranch(pattern.getItems(), pattern.getCount(), root));
        }

        /**
         * Keeps the items of {@code pattern} that are present in
         * {@code frequentItemCounts}, sorts them with the non-streaming order and
         * inserts them below {@code rootOfBranch} with the given count.
         *
         * @param pattern      items to insert
         * @param count        amount added to every node on the inserted path
         * @param rootOfBranch node to insert below
         */
        public void reinsertBranch(Set<Integer> pattern, double count, FPTreeNode rootOfBranch) {
            List<Integer> kept = new ArrayList<>();
            for (Integer candidate : pattern) {
                if (frequentItemCounts.containsKey(candidate)) {
                    kept.add(candidate);
                }
            }

            sortTransaction(kept, false);
            rootOfBranch.insertTransaction(kept, 0, count, false);
        }


        /**
         * Inserts each transaction in turn with the given flags.
         *
         * @param transactions                    transactions to insert
         * @param streaming                       passed through to {@code insertTransaction}
         * @param filterExistingFrequentItemsOnly passed through to {@code insertTransaction}
         */
        public void insertTransactions(List<Set<Integer>> transactions, boolean streaming, boolean filterExistingFrequentItemsOnly) {
            transactions.forEach(txn -> insertTransaction(txn, streaming, filterExistingFrequentItemsOnly));
        }

        /**
         * Inserts one transaction (weight 1) below the root.
         *
         * <p>Streaming without filtering: every item's count is bumped first, so every
         * item passes the filter. Streaming with filtering: only items already present
         * in {@code frequentItemCounts} are kept, and only those are bumped.
         * Non-streaming: counts are left untouched. A transaction with no surviving
         * item is dropped.
         *
         * @param transaction                     items of the transaction
         * @param streaming                       whether this is a streaming insert
         * @param filterExistingFrequentItemsOnly in streaming mode, whether unknown items are ignored
         */
        public void insertTransaction(Collection<Integer> transaction, boolean streaming, boolean filterExistingFrequentItemsOnly) {
            final boolean countEverything = streaming && !filterExistingFrequentItemsOnly;
            if (countEverything) {
                for (Integer item : transaction) {
                    frequentItemCounts.merge(item, 1.0, Double::sum);
                }
            }

            List<Integer> known = new ArrayList<>();
            for (Integer item : transaction) {
                if (frequentItemCounts.containsKey(item)) {
                    known.add(item);
                }
            }

            if (known.isEmpty()) {
                return;
            }

            if (streaming && filterExistingFrequentItemsOnly) {
                for (Integer item : known) {
                    frequentItemCounts.merge(item, 1.0, Double::sum);
                }
            }

            sortTransaction(known, streaming);
            root.insertTransaction(known, 0, 1, streaming);
        }

        /**
         * Mines itemsets from this tree.
         *
         * <p>Phase 1 follows the chain of only-children down from the root for as long
         * as node counts are at least {@code supportCountRequired}, and emits every
         * non-empty subset of that single path with the smallest count in the subset.
         * Phase 2 runs only if a node with more than one child was reached: for each
         * header item not on the single path whose count meets the threshold it emits
         * the singleton, builds a conditional tree from the prefix paths of that item's
         * nodes, mines it recursively and adds the item to each result. Finally the
         * single-path and branching results are combined pairwise.
         *
         * @param supportCountRequired absolute count threshold
         * @return the mined itemsets with their counts
         */
        List<ItemsetWithCount> mineItemsets(Integer supportCountRequired) {
            List<ItemsetWithCount> singlePathItemsets = new ArrayList<>();
            List<ItemsetWithCount> branchingItemsets = new ArrayList<>();

            // mine single-path itemsets first
            FPTreeNode curNode = root;
            FPTreeNode nodeOfBranching = null;
            Set<FPTreeNode> singlePathNodes = new HashSet<>();
            while (true) {
                // stop as soon as the path drops below the threshold
                if (curNode.count < supportCountRequired) {
                    break;
                }

                // first node with several children: the single path ends here
                // (this node itself is not added to singlePathNodes)
                if (curNode.children != null && curNode.children.size() > 1) {
                    nodeOfBranching = curNode;
                    break;
                }

                if (curNode != root) {
                    singlePathNodes.add(curNode);
                }

                if (curNode.children == null || curNode.children.size() == 0) {
                    break;
                } else {
                    curNode = curNode.children.get(0);
                }
            }

            // every non-empty subset of the single path is an itemset whose count is
            // the smallest node count within the subset
            for (Set<FPTreeNode> subset : Sets.powerSet(singlePathNodes)) {
                if (subset.isEmpty()) {
                    continue;
                }

                // -1 marks "no count seen yet"
                double minSupportInSubset = -1;
                Set<Integer> items = new HashSet<>();
                for (FPTreeNode n : subset) {
                    items.add(n.getItem());

                    if (minSupportInSubset == -1 || n.getCount() < minSupportInSubset) {
                        minSupportInSubset = n.getCount();
                    }
                }

                assert (minSupportInSubset >= supportCountRequired);
                singlePathItemsets.add(new ItemsetWithCount(items, minSupportInSubset));
            }

            // the entire tree was a single path...
            if (nodeOfBranching == null) {
                return singlePathItemsets;
            }

            // all of the items in the single path will have been mined now
            // due to the descending frequency count of the StreamingFPTree structure, so
            // we remove them from consideration in the rest

            // instead of destructively removing the nodes from NodeHeader table
            // which would be valid but would make mining non-idempotent, we
            // instead store the nodes to skip in a separate set

            Set<Integer> alreadyMinedItems = new HashSet<>();
            for (FPTreeNode node : singlePathNodes) {
                alreadyMinedItems.add(node.getItem());
            }

            for (Map.Entry<Integer, FPTreeNode> header : nodeHeaders.entrySet()) {
                // NOTE(review): the count lookup is unboxed, so this presumably relies on
                // every header item also being a key of frequentItemCounts -- confirm
                if (alreadyMinedItems.contains(header.getKey())
                    || frequentItemCounts.get(header.getKey()) < supportCountRequired) {
                    continue;
                }

                // add the singleton item set
                branchingItemsets.add(new ItemsetWithCount(Sets.newHashSet(header.getKey()),
                                                           frequentItemCounts.get(header.getKey())));

                List<ItemsetWithCount> conditionalPatternBase = new ArrayList<>();

                // walk each "leaf" node
                // (i.e. every node on this item's header chain)
                FPTreeNode conditionalNode = header.getValue();
                while (conditionalNode != null) {
                    final double leafSupport = conditionalNode.getCount();

                    // walk the tree up to the branch node
                    // (the branch node itself is included unless it is the root)
                    Set<Integer> conditionalPattern = new HashSet<>();
                    FPTreeNode walkNode = conditionalNode.getParent();
                    while (walkNode != nodeOfBranching.getParent() && walkNode != root) {
                        conditionalPattern.add(walkNode.getItem());
                        walkNode = walkNode.getParent();
                    }

                    if (conditionalPattern.size() > 0) {
                        conditionalPatternBase.add(new ItemsetWithCount(conditionalPattern, leafSupport));
                    }

                    conditionalNode = conditionalNode.getNextLink();
                }

                if (conditionalPatternBase.isEmpty()) {
                    continue;
                }

                // build and mine the conditional StreamingFPTree
                StreamingFPTree conditionalTree = new StreamingFPTree();
                conditionalTree.insertConditionalFrequentItems(conditionalPatternBase, supportCountRequired);
                conditionalTree.insertConditionalFrequentPatterns(conditionalPatternBase);
                List<ItemsetWithCount> conditionalFrequentItemsets = conditionalTree.mineItemsets(supportCountRequired);

                if (!conditionalFrequentItemsets.isEmpty()) {
                    // extend each conditional result with the item it was conditioned on
                    for (ItemsetWithCount is : conditionalFrequentItemsets) {
                        is.getItems().add(header.getKey());
                    }

                    branchingItemsets.addAll(conditionalFrequentItemsets);
                }
            }

            if (singlePathItemsets.isEmpty()) {
                return branchingItemsets;
            }

            // take the cross product of the mined itemsets
            List<ItemsetWithCount> ret = new ArrayList<>();

            ret.addAll(singlePathItemsets);
            ret.addAll(branchingItemsets);

            for (ItemsetWithCount i : singlePathItemsets) {
                for (ItemsetWithCount j : branchingItemsets) {
                    Set<Integer> combinedItems = new HashSet<>();
                    combinedItems.addAll(i.getItems());
                    combinedItems.addAll(j.getItems());
                    // a combined itemset gets the smaller of the two counts
                    ret.add(new ItemsetWithCount(combinedItems, Math.min(i.getCount(), j.getCount())));
                }
            }

            return ret;
        }

        /**
         * Unlinks {@code node} from the header chain of its item and drops it from
         * {@code leafNodes}. If the node was the head of the chain, the header entry
         * is pointed at its successor (which may be null). The node's own link fields
         * are left as they were.
         *
         * @param node node to unlink
         */
        private void removeNodeFromHeaders(FPTreeNode node) {
            leafNodes.remove(node);

            final FPTreeNode before = node.getPrevLink();
            final FPTreeNode after = node.getNextLink();

            if (before != null) {
                before.setNextLink(after);
            } else {
                // no predecessor: the node must be the current chain head
                assert (nodeHeaders.get(node.getItem()) == node);
                nodeHeaders.put(node.getItem(), after);
            }

            if (after != null) {
                after.setPrevLink(before);
            }
        }

        /**
         * Rewrites unsorted branches so they follow the current
         * {@code frequentItemOrder}.
         *
         * <p>For every leaf that is neither in {@code sortedNodes} nor already removed,
         * the leaf is unlinked and its ancestors are walked towards the root: each
         * visited ancestor's item is collected and its count is reduced by the leaf's
         * count. The walk stops at the root, or at a node in {@code sortedNodes} whose
         * order is lower than every order seen so far on this walk. The collected
         * items are then reinserted below the stopping node with the leaf's count.
         */
        private void sortByNewOrder() {
            // we need to walk the tree from each leaf to each root

            // snapshot of the leaves; the list grows while iterating when an ancestor
            // becomes a childless, unsorted node
            List<FPTreeNode> leavesToInspect = Lists.newArrayList(leafNodes);
            Set<FPTreeNode> removedNodes = new HashSet<>();
            for (int i = 0; i < leavesToInspect.size(); ++i) {
                FPTreeNode leaf = leavesToInspect.get(i);

                if (leaf == root) {
                    continue;
                }

                if (removedNodes.contains(leaf) || sortedNodes.contains(leaf)) {
                    continue;
                }

                double leafCount = leaf.getCount();
                // items of the branch that will be reinserted in the new order
                Set<Integer> toInsert = new HashSet<>();

                toInsert.add(leaf.getItem());

                assert (!leaf.hasChildren());

                removeNodeFromHeaders(leaf);

                removedNodes.add(leaf);

                // lowest order value seen so far on this walk
                int curLowestNodeOrder = frequentItemOrder.get(leaf.getItem());

                FPTreeNode node = leaf.getParent();
                node.removeChild(leaf);
                while (true) {
                    if (node == root) {
                        break;
                    }

                    int nodeOrder = frequentItemOrder.get(node.getItem());
                    // a sorted ancestor ordered below everything collected so far can
                    // stay in place; the branch is reinserted beneath it
                    if (sortedNodes.contains(node) && nodeOrder < curLowestNodeOrder) {
                        break;
                    } else if (nodeOrder < curLowestNodeOrder) {
                        curLowestNodeOrder = nodeOrder;
                    }

                    assert (!removedNodes.contains(node));

                    toInsert.add(node.getItem());

                    node.decrementCount(leafCount);
                    // this node no longer has support, so remove it...
                    // NOTE(review): exact comparison of a double with 0 -- confirm this
                    // is still reached after counts have been scaled by decayWeights
                    if (node.getCount() == 0 && !node.hasChildren()) {
                        removedNodes.add(node);
                        removeNodeFromHeaders(node);
                        node.getParent().removeChild(node);
                        // still has support but is unsorted, so we'd better check it out
                    } else if (!node.hasChildren() && !sortedNodes.contains(node)) {
                        leavesToInspect.add(node);
                    }

                    node = node.getParent();
                }

                // insertTransaction adds the count to the node it starts at, so take it
                // off the stopping node first to leave that node's count unchanged
                node.decrementCount(leafCount);

                reinsertBranch(toInsert, leafCount, node);
            }
        }
    }

    /**
     * Streams a batch of transactions into the tree, counting every item
     * (streaming insert without filtering), and flags the tree for restructuring.
     *
     * @param transactions transactions to insert
     */
    public void insertTransactionsStreamingExact(List<Set<Integer>> transactions) {
        needsRestructure = true;

        fp.insertTransactions(transactions, true, false);
    }

    /**
     * Streams one transaction into the tree, counting every item (streaming insert
     * without filtering), and flags the tree for restructuring.
     *
     * @param transaction items of the transaction
     */
    public void insertTransactionStreamingExact(Collection<Integer> transaction) {
        needsRestructure = true;

        fp.insertTransaction(transaction, true, false);
    }

    /**
     * Streams a batch of transactions into the tree, keeping only items that are
     * already tracked in the tree's item counts (streaming insert with filtering),
     * and flags the tree for restructuring.
     *
     * @param transactions transactions to insert
     */
    public void insertTransactionsStreamingFalseNegative(List<Set<Integer>> transactions) {
        needsRestructure = true;

        fp.insertTransactions(transactions, true, true);
    }

    /**
     * Streams one transaction into the tree, keeping only items that are already
     * tracked in the tree's item counts (streaming insert with filtering), and
     * flags the tree for restructuring.
     *
     * @param transaction items of the transaction
     */
    public void insertTransactionStreamingFalseNegative(Collection<Integer> transaction) {
        needsRestructure = true;

        fp.insertTransaction(transaction, true, true);
    }

    /**
     * Brings the tree back into sorted form: deletes the given items, rebuilds the
     * item order from the current counts and re-sorts the unsorted branches.
     * Clears {@code needsRestructure} before doing any work.
     *
     * <p>The timer context is stopped in a {@code finally} block so the timing is
     * recorded even if one of the steps throws.
     *
     * @param itemsToDelete items to drop from the tree; may be null
     */
    public void restructureTree(Set<Integer> itemsToDelete) {
        needsRestructure = false;
        // todo: prune infrequent items
        Timer.Context context = restructureTree.time();
        try {
            fp.deleteItems(itemsToDelete);
            fp.updateFrequentItemOrder();
            fp.sortByNewOrder();
        } finally {
            context.stop();
        }
    }

    /**
     * Builds the tree from a batch: first records the items whose count reaches
     * {@code support * transactions.size()} (truncated to an int), then inserts
     * every transaction in non-streaming mode.
     *
     * <p>NOTE(review): the guard reads {@code startedStreaming}, which nothing in
     * this file sets to true -- confirm whether it is assigned elsewhere.
     *
     * @param transactions transactions to build the tree from
     * @throws RuntimeException if {@code startedStreaming} is true
     */
    public void buildTree(List<Set<Integer>> transactions) {
        if (startedStreaming) {
            throw new RuntimeException("Can't build a tree based on an already streaming tree...");
        }

        final int minimumCount = (int) (support * transactions.size());

        fp.insertFrequentItems(transactions, minimumCount);
        fp.insertTransactions(transactions, false, false);
    }

    /**
     * Replaces the tree's item counts with {@code newFrequentItems}, optionally
     * decays all counts, removes every item that is no longer frequent and
     * restructures the tree.
     *
     * <p>The set of items to remove covers both the items known to the current
     * ordering and the items that still have nodes in the tree. The streaming sort
     * only registers an order for items its comparator actually sees, so an item
     * can be in the tree without an order entry; leaving such an item in place
     * would make the subsequent re-sort fail when it looks up the item's order.
     *
     * <p>{@code newFrequentItems} is stored by reference, and when
     * {@code decayRate > 0} its values are scaled in place by {@code 1 - decayRate}.
     *
     * @param newFrequentItems item counts to use from now on
     * @param decayRate        when positive, all counts are multiplied by {@code 1 - decayRate}
     */
    public void decayAndResetFrequentItems(Map<Integer, Double> newFrequentItems, double decayRate) {
        Set<Integer> toRemove = Sets.difference(
                Sets.union(fp.frequentItemOrder.keySet(), fp.nodeHeaders.keySet()),
                newFrequentItems.keySet()).immutableCopy();
        fp.frequentItemCounts = newFrequentItems;
        fp.updateFrequentItemOrder();
        if (decayRate > 0) {
            fp.decayWeights(fp.root, (1 - decayRate));
        }
        restructureTree(toRemove);
    }

    /**
     * Looks up the tree's support for each target itemset, restructuring the tree
     * first if streaming inserts have happened since the last restructure.
     *
     * @param targets itemsets to look up
     * @return one entry per target, in the same order, carrying the target's items
     *         and the support found in the tree
     */
    public List<ItemsetWithCount> getCounts(List<ItemsetWithCount> targets) {
        if (needsRestructure) {
            restructureTree(null);
        }

        final int expected = targets.size();
        return targets.stream()
                .map(target -> new ItemsetWithCount(target.getItems(), fp.getSupport(target.getItems())))
                .collect(Collectors.toCollection(() -> new ArrayList<>(expected)));
    }

    /**
     * Mines the tree for itemsets whose count reaches the root count multiplied by
     * {@code support} (truncated to an int), restructuring the tree first if
     * streaming inserts have happened since the last restructure.
     *
     * <p>The timer context is stopped in a {@code finally} block so the timing is
     * recorded even if mining throws.
     *
     * @return the mined itemsets with their counts
     */
    public List<ItemsetWithCount> getItemsets() {
        if (needsRestructure) {
            restructureTree(null);
        }

        Timer.Context context = fpMine.time();
        try {
            return fp.mineItemsets((int) (fp.root.getCount() * support));
        } finally {
            context.stop();
        }
    }

    /**
     * Writes the tree's item counts, item order and nodes to the debug log.
     */
    public void printTreeDebug() {
        fp.printTreeDebug();
    }
}