/**
 * Copyright (c) 2015 Lemur Consulting Ltd.
 * <p/>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package uk.co.flax.biosolr.pruning;

import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;

import uk.co.flax.biosolr.TreeFacetField;

/**
 * Prune a facet hierarchy tree into its most significant data points,
 * with all other points grouped into a single "other" node.
 *
 * <p>Nodes are picked in rounds using a decreasing hit-count threshold until
 * the requested number of datapoints has been collected (or no more nodes
 * qualify). Whatever is left over is pruned with a {@link SimplePruner} and
 * attached beneath a node labelled with the configured "more" label.</p>
 *
 * @author mlp
 */
public class DatapointPruner implements Pruner {
	
	/** Label available to callers as a default for the "other" node. */
	public static final String DEFAULT_MORE_LABEL = "Others";
	
	/** Maximum number of significant nodes to return (excluding the "other" node). */
	private final int datapoints;
	/** Label given to the node which collects all remaining entries. */
	private final String moreLabel;

	/**
	 * Create a pruner.
	 * @param datapoints the maximum number of significant nodes to keep.
	 * @param moreLabel the label for the node holding everything else.
	 */
	public DatapointPruner(int datapoints, String moreLabel) {
		this.datapoints = datapoints;
		this.moreLabel = moreLabel;
	}

	/**
	 * Reduce the incoming trees to at most {@code datapoints} significant
	 * nodes, plus one extra "other" node if any hits remain outside them.
	 *
	 * <p>The selection phase works on clones of the incoming trees. The
	 * second phase (building the "other" node) works on the caller's own
	 * tree objects: their child hierarchies are trimmed and their child
	 * counts recalculated, so the passed-in trees should be treated as
	 * consumed by this call.</p>
	 *
	 * @param unprunedTrees the full set of top-level trees.
	 * @return the selected nodes, ordered by descending count (ties broken
	 * by descending value), followed by the "other" node when it has a
	 * total greater than zero.
	 */
	@Override
	public Collection<TreeFacetField> prune(Collection<TreeFacetField> unprunedTrees) {
		// Highest count first; equal counts are ordered by value, also reversed
		Collection<TreeFacetField> prunedTrees = new TreeSet<>(Comparator.comparingLong(TreeFacetField::getCount)
				.thenComparing(TreeFacetField::getValue).reversed());
		// Clone the unpruned collection - we need it again later
		Collection<TreeFacetField> incoming = unprunedTrees.stream().map(TreeFacetField::clone).collect(Collectors.toList());

		long total = getNodeTotal(incoming);
		int itCount = 1;
		int prevCount = Integer.MAX_VALUE;
		
		// Each round lowers the threshold, pulling in progressively smaller nodes
		while (prunedTrees.size() < datapoints && !incoming.isEmpty()) {
			int minCount = getThreshold(itCount, prevCount, total);
			if (minCount <= 0) {
				// Threshold exhausted - nothing further can qualify
				break;
			}
			
			prunedTrees.addAll(getNodesWithCount(incoming, minCount));
			
			itCount ++;
			prevCount = minCount;
		}
		
		/* Trim the pruned trees list to the number of datapoints.
		 * This leaves the incoming list copy potentially missing nodes which
		 * should be in the "other" node. Since they could be anywhere, we
		 * have to rebuild it from scratch.
		 */
		if (prunedTrees.size() > datapoints) {
			// Collect into an explicitly mutable list: the "other" node is
			// appended below, and Collectors.toList() makes no promise that
			// its result supports add().
			prunedTrees = prunedTrees.stream().limit(datapoints)
					.collect(Collectors.toCollection(LinkedList::new));
		}
		
		// Rebuild the incoming node set - no need to clone...
		incoming = new LinkedList<>(unprunedTrees);
		// ...and strip the nodes already extracted to the pruned list
		trimIncomingNodes(incoming, prunedTrees, 0);
		
		// Build the "other" node
		TreeFacetField otherNode = buildOtherNode(incoming);
		if (otherNode.getTotal() > 0) {
			prunedTrees.add(otherNode);
		}
		
		return prunedTrees;
	}
	
	/**
	 * Calculate the minimum hit count for a selection round. The threshold
	 * is the average hits per datapoint divided by the iteration number,
	 * capped so that it is always strictly lower than the previous round's.
	 * @param iteration the 1-based round number.
	 * @param previous the threshold used in the previous round.
	 * @param total the total hit count across all trees.
	 * @return the threshold; a value of zero or less means there is no
	 * point in carrying on.
	 */
	private int getThreshold(int iteration, int previous, long total) {
		// Stay in long arithmetic: passing a long to Math.round() silently
		// widens it to float, which cannot represent large totals exactly.
		long average = (total / datapoints) / iteration;
		// The result never exceeds previous - 1, so for non-negative totals
		// it always fits in an int.
		int min = (int) Math.min(average, (long) previous - 1);
		
		if (min == 0 && iteration == 1) {
			// First iteration - set minCount to 1
			min = 1;
		}
		
		return min;
	}
	
	/**
	 * Extract all nodes in a collection with a hit count greater or equal
	 * to a given threshold. This has the side effect of modifying the
	 * incoming node collection.
	 * @param incoming the incoming nodes. Matching nodes will be removed
	 * during the processing.
	 * @param threshold the minimum hit count required to be returned.
	 * @return the collection of nodes whose hit count is greater than or
	 * equal to the threshold. 
	 */
	private Collection<TreeFacetField> getNodesWithCount(Collection<TreeFacetField> incoming, long threshold) {
		Collection<TreeFacetField> retList = new LinkedList<>();
		
		for (Iterator<TreeFacetField> iter = incoming.iterator(); iter.hasNext(); ) {
			TreeFacetField tff = iter.next();
			// Skip whole subtrees which cannot contain a qualifying node
			if (tff.getTotal() >= threshold) {
				if (tff.getChildCount() >= threshold) {
					// Recurse, finding the nodes with enough hits
					retList.addAll(getNodesWithCount(tff.getHierarchy(), threshold));
					// Recalculate the child count throughout the tree
					tff.recalculateChildCount();
				}
				
				if (tff.getCount() >= threshold) {
					// This node has enough hits - store, and remove from the 
					// incoming nodes so it's not picked again later.
					retList.add(tff);
					iter.remove();
				}
			}
		}
		
		return retList;
	}
	
	/**
	 * Get the total node count for all trees.
	 * @param trees the trees whose total count is required.
	 * @return the sum of the totals of every tree in the collection.
	 */
	private long getNodeTotal(Collection<TreeFacetField> trees) {
		return trees.stream().mapToLong(TreeFacetField::getTotal).sum();
	}
	
	/**
	 * Remove a collection of pruned nodes from the original incoming set.
	 * Nodes are removed in place, at every depth, and the child counts of
	 * the surviving top-level nodes are recalculated afterwards.
	 * @param incoming the set containing all nodes in the tree.
	 * @param pruned the nodes to check for duplicates.
	 * @param level the current level in the tree, starting from 0.
	 */
	private void trimIncomingNodes(Collection<TreeFacetField> incoming, Collection<TreeFacetField> pruned, int level) {
		for (Iterator<TreeFacetField> it = incoming.iterator(); it.hasNext(); ) {
			TreeFacetField tff = it.next();
			if (isFacetInChildren(tff, pruned)) {
				// Already selected as a datapoint - keep it out of "other"
				it.remove();
			} else {
				if (tff.hasChildren()) {
					trimIncomingNodes(tff.getHierarchy(), pruned, level + 1);
				}

				if (level == 0) {
					// Update the child counts in the node and its children
					tff.recalculateChildCount();
				}
			}
		}
	}
	
	/**
	 * Check whether a particular facet is present anywhere in a collection
	 * of trees - either as one of the trees themselves, or somewhere in
	 * their child hierarchies.
	 * @param facet the facet to check for.
	 * @param trees the collection of trees to check through. May be
	 * <code>null</code>, in which case nothing is found.
	 * @return <code>true</code> if the facet is found.
	 */
	private boolean isFacetInChildren(TreeFacetField facet, Collection<TreeFacetField> trees) {
		if (trees == null) {
			return false;
		}
		
		for (TreeFacetField tree : trees) {
			if (tree.equals(facet) || isFacetInChildren(facet, tree.getHierarchy())) {
				return true;
			}
		}
		
		return false;
	}
	
	/**
	 * Build the node which holds everything not selected as a datapoint.
	 * @param otherNodes the nodes remaining after the datapoints have been
	 * stripped out.
	 * @return a node labelled with the "more" label, whose hierarchy is the
	 * remaining nodes after pruning with a {@link SimplePruner}.
	 */
	private TreeFacetField buildOtherNode(Collection<TreeFacetField> otherNodes) {
		// Prune the other nodes - use the SimplePruner
		SortedSet<TreeFacetField> pruned = new TreeSet<>(Comparator.reverseOrder());
		pruned.addAll(new SimplePruner(SimplePruner.MIN_CHILD_COUNT).prune(otherNodes));
		
		TreeFacetField other = new TreeFacetField(moreLabel, "", 0, 0, pruned);
		// The node is created with zero counts; derive them from its children
		other.recalculateChildCount();
		
		return other;
	}

}