package edu.berkeley.nlp.syntax;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.Set;

import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.MapFactory;
import edu.berkeley.nlp.util.MyMethod;
import edu.berkeley.nlp.util.Pair;

/**
 * Represents a linguistic tree: each node carries a label and an ordered list
 * of child trees. A node with no children is a leaf (terminal); a node whose
 * single child is a leaf is a preterminal.
 * <p>
 * {@link #equals(Object)} and {@link #hashCode()} are structural (label plus
 * children, recursively), so two distinct nodes with identical subtrees are
 * equal. The class also provides maps from subtree nodes to the
 * {@link Constituent} spans they cover.
 * 
 * @author Dan Klein
 */
public class Tree<L> implements Serializable, Comparable<Tree<L>>,
		Iterable<Tree<L>> {

	private static final long serialVersionUID = 1L;

	/** Label of this node; may be {@code null}. */
	L label;

	/** Children of this node; an empty list for leaves. */
	List<Tree<L>> children;

	/**
	 * Replaces the child at position {@code i} by delegating to the children
	 * list's {@code set}, so the list must be mutable.
	 */
	public void setChild(int i, Tree<L> child) {
		children.set(i, child);
	}

	/** Replaces the children list. The list is stored by reference, not copied. */
	public void setChildren(List<Tree<L>> c) {
		this.children = c;
	}

	/** Returns the live children list (not a copy). */
	public List<Tree<L>> getChildren() {
		return children;
	}

	/** Returns the child at position {@code i}. */
	public Tree<L> getChild(int i) {
		return children.get(i);
	}

	/** Returns this node's label, possibly {@code null}. */
	public L getLabel() {
		return label;
	}

	/** Returns true if this node has no children. */
	public boolean isLeaf() {
		return getChildren().isEmpty();
	}

	/** Returns true if this node has exactly one child and that child is a leaf. */
	public boolean isPreTerminal() {
		return getChildren().size() == 1 && getChildren().get(0).isLeaf();
	}

	/** Returns the labels of all leaves under this node, left to right. */
	public List<L> getYield() {
		List<L> yield = new ArrayList<L>();
		appendYield(this, yield);
		return yield;
	}

	/**
	 * Returns the constituents of this tree. Leaves and preterminals both
	 * produce single-word constituents; a preterminal produces one constituent
	 * carrying its own label, and the leaf below it is not listed separately.
	 * Word indices are 0-based and constituent ends are inclusive.
	 */
	public Collection<Constituent<L>> getConstituentCollection() {
		Collection<Constituent<L>> constituents = new ArrayList<Constituent<L>>();
		appendConstituent(this, constituents, 0);
		return constituents;
	}

	/**
	 * Maps every node of this tree, leaves included, to the constituent it
	 * spans. Leaf indices are 0-based and constituent ends are inclusive.
	 * <p>
	 * An {@link IdentityHashMap} is used deliberately. Tree equality is
	 * structural, so in an equality-based map two structurally identical
	 * subtrees (for example, the same word occurring twice) would share one key.
	 * Each one would then overwrite the other's entry.
	 */
	public Map<Tree<L>, Constituent<L>> getConstituents() {
		Map<Tree<L>, Constituent<L>> constituents = new IdentityHashMap<Tree<L>, Constituent<L>>();
		appendConstituent(this, constituents, 0);
		return constituents;
	}

	/**
	 * Groups every node of this tree by the span of leaves it covers.
	 * <p>
	 * Keys are half-open {@code (start, end)} leaf-index pairs (end exclusive).
	 * Each value lists the nodes covering exactly that span, sorted by
	 * decreasing {@link #getDepth()}. Nodes sharing a span are ancestors and
	 * descendants of one another, so higher nodes come first.
	 */
	public Map<Pair<Integer, Integer>, List<Tree<L>>> getSpanMap() {
		Map<Tree<L>, Constituent<L>> cMap = getConstituents();
		Map<Pair<Integer, Integer>, List<Tree<L>>> spanMap = new HashMap<Pair<Integer, Integer>, List<Tree<L>>>();
		for (Map.Entry<Tree<L>, Constituent<L>> entry : cMap.entrySet()) {
			Constituent<L> c = entry.getValue();
			// Constituent ends are inclusive; convert to an exclusive end.
			Pair<Integer, Integer> span = Pair.newPair(c.getStart(),
					c.getEnd() + 1);
			CollectionUtils.addToValueList(spanMap, span, entry.getKey());
		}
		Comparator<Tree<L>> tallestFirst = new Comparator<Tree<L>>() {
			public int compare(Tree<L> t1, Tree<L> t2) {
				int d1 = t1.getDepth();
				int d2 = t2.getDepth();
				return (d1 > d2) ? -1 : ((d1 == d2) ? 0 : 1);
			}
		};
		for (List<Tree<L>> trees : spanMap.values()) {
			Collections.sort(trees, tallestFirst);
		}
		return spanMap;
	}

	/**
	 * Same as {@link #getConstituents()}, but stores results in a map built by
	 * {@code mf}. If {@code mf} builds an equality-based map, structurally
	 * identical subtrees will share a single entry.
	 */
	public Map<Tree<L>, Constituent<L>> getConstituents(MapFactory mf) {
		Map<Tree<L>, Constituent<L>> constituents = mf.buildMap();
		appendConstituent(this, constituents, 0);
		return constituents;
	}

	/**
	 * Records a constituent for {@code tree} and every node below it. The
	 * leftmost leaf of {@code tree} gets index {@code index}.
	 * 
	 * @return the number of leaves under {@code tree}
	 */
	private static <L> int appendConstituent(Tree<L> tree,
			Map<Tree<L>, Constituent<L>> constituents, int index) {
		if (tree.isLeaf()) {
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index, index);
			constituents.put(tree, c);
			return 1; // a leaf spans exactly one word
		} else {
			int nextIndex = index;
			for (Tree<L> kid : tree.getChildren()) {
				nextIndex += appendConstituent(kid, constituents, nextIndex);
			}
			// nextIndex is one past the last word covered; ends are inclusive.
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index,
					nextIndex - 1);
			constituents.put(tree, c);
			return nextIndex - index; // number of words spanned
		}
	}

	/**
	 * Collection variant of
	 * {@link #appendConstituent(Tree, Map, int)}. It stops at preterminals
	 * and records them as single-word constituents.
	 * 
	 * @return the number of words under {@code tree}
	 */
	private static <L> int appendConstituent(Tree<L> tree,
			Collection<Constituent<L>> constituents, int index) {
		if (tree.isLeaf() || tree.isPreTerminal()) {
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index, index);
			constituents.add(c);
			return 1; // leaves and preterminals span exactly one word
		} else {
			int nextIndex = index;
			for (Tree<L> kid : tree.getChildren()) {
				nextIndex += appendConstituent(kid, constituents, nextIndex);
			}
			Constituent<L> c = new Constituent<L>(tree.getLabel(), index,
					nextIndex - 1);
			constituents.add(c);
			return nextIndex - index; // number of words spanned
		}
	}

	/** Appends every non-leaf node (preterminals included) in pre-order. */
	private static <L> void appendNonTerminals(Tree<L> tree, List<Tree<L>> yield) {
		if (tree.isLeaf()) {
			return;
		}
		yield.add(tree);
		for (Tree<L> child : tree.getChildren()) {
			appendNonTerminals(child, yield);
		}
	}

	/** Returns the leaf nodes of this tree, left to right. */
	public List<Tree<L>> getTerminals() {
		List<Tree<L>> yield = new ArrayList<Tree<L>>();
		appendTerminals(this, yield);
		return yield;
	}

	/** Returns all non-leaf nodes (preterminals included) in pre-order. */
	public List<Tree<L>> getNonTerminals() {
		List<Tree<L>> yield = new ArrayList<Tree<L>>();
		appendNonTerminals(this, yield);
		return yield;
	}

	/** Appends the leaf nodes under {@code tree}, left to right. */
	private static <L> void appendTerminals(Tree<L> tree, List<Tree<L>> yield) {
		if (tree.isLeaf()) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendTerminals(child, yield);
		}
	}

	/**
	 * Clones the structure of the tree. Every node is new, but labels are
	 * shared by reference with this tree rather than copied.
	 * 
	 * @return a structural copy of this tree
	 */
	public Tree<L> shallowClone() {
		ArrayList<Tree<L>> newChildren = new ArrayList<Tree<L>>(children.size());
		for (Tree<L> child : children) {
			newChildren.add(child.shallowClone());
		}
		return new Tree<L>(label, newChildren);
	}

	/**
	 * Returns a copy of just the root node of this tree, with no children.
	 * The label is shared by reference.
	 * 
	 * @return a childless node carrying this node's label
	 */
	public Tree<L> shallowCloneJustRoot() {
		return new Tree<L>(label);
	}

	/** Appends the labels of the leaves under {@code tree}, left to right. */
	private static <L> void appendYield(Tree<L> tree, List<L> yield) {
		if (tree.isLeaf()) {
			yield.add(tree.getLabel());
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendYield(child, yield);
		}
	}

	/** Returns the labels of all preterminal nodes, left to right. */
	public List<L> getPreTerminalYield() {
		List<L> yield = new ArrayList<L>();
		appendPreTerminalYield(this, yield);
		return yield;
	}

	/** Returns the labels of all leaf nodes, left to right. */
	public List<L> getTerminalYield() {
		List<Tree<L>> terms = getTerminals();
		List<L> yield = new ArrayList<L>();
		for (Tree<L> term : terms) {
			yield.add(term.getLabel());
		}
		return yield;
	}

	/** Returns all preterminal nodes, left to right. */
	public List<Tree<L>> getPreTerminals() {
		List<Tree<L>> preterms = new ArrayList<Tree<L>>();
		appendPreTerminals(this, preterms);
		return preterms;
	}

	/**
	 * Returns the highest nodes whose height ({@link #getDepth()}) equals
	 * {@code depth}. The search does not descend below a matching node. Note
	 * that {@link #getAtDepth(int)} instead measures distance from the root.
	 */
	public List<Tree<L>> getTreesOfDepth(int depth) {
		List<Tree<L>> trees = new ArrayList<Tree<L>>();
		appendTreesOfDepth(this, trees, depth);
		return trees;
	}

	/** Appends labels of preterminals under {@code tree}, left to right. */
	private static <L> void appendPreTerminalYield(Tree<L> tree, List<L> yield) {
		if (tree.isPreTerminal()) {
			yield.add(tree.getLabel());
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendPreTerminalYield(child, yield);
		}
	}

	/** Appends preterminal nodes under {@code tree}, left to right. */
	private static <L> void appendPreTerminals(Tree<L> tree, List<Tree<L>> yield) {
		if (tree.isPreTerminal()) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendPreTerminals(child, yield);
		}
	}

	/** Helper for {@link #getTreesOfDepth(int)}. */
	private static <L> void appendTreesOfDepth(Tree<L> tree,
			List<Tree<L>> yield, int depth) {
		if (tree.getDepth() == depth) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendTreesOfDepth(child, yield, depth);
		}
	}

	/** Returns all nodes, each node listed before its children. */
	public List<Tree<L>> getPreOrderTraversal() {
		ArrayList<Tree<L>> traversal = new ArrayList<Tree<L>>();
		traversalHelper(this, traversal, true);
		return traversal;
	}

	/** Returns all nodes, each node listed after its children. */
	public List<Tree<L>> getPostOrderTraversal() {
		ArrayList<Tree<L>> traversal = new ArrayList<Tree<L>>();
		traversalHelper(this, traversal, false);
		return traversal;
	}

	/** Depth-first walk that records nodes in pre- or post-order. */
	private static <L> void traversalHelper(Tree<L> tree,
			List<Tree<L>> traversal, boolean preOrder) {
		if (preOrder)
			traversal.add(tree);
		for (Tree<L> child : tree.getChildren()) {
			traversalHelper(child, traversal, preOrder);
		}
		if (!preOrder)
			traversal.add(tree);
	}

	/**
	 * Returns the height of this node: 1 for a leaf, otherwise one more than
	 * the tallest child. Recomputed on every call.
	 */
	public int getDepth() {
		int maxDepth = 0;
		for (Tree<L> child : children) {
			int depth = child.getDepth();
			if (depth > maxDepth)
				maxDepth = depth;
		}
		return maxDepth + 1;
	}

	/** Returns the number of nodes in this tree, this node included. */
	public int size() {
		int sum = 0;
		for (Tree<L> child : children) {
			sum += child.size();
		}
		return sum + 1;
	}

	/**
	 * Returns the nodes exactly {@code depth} edges below this node, left to
	 * right. Depth 0 returns this node, and a negative depth returns an empty
	 * list.
	 */
	public List<Tree<L>> getAtDepth(int depth) {
		List<Tree<L>> yield = new ArrayList<Tree<L>>();
		appendAtDepth(depth, this, yield);
		return yield;
	}

	/** Helper for {@link #getAtDepth(int)}. */
	private static <L> void appendAtDepth(int depth, Tree<L> tree,
			List<Tree<L>> yield) {
		if (depth < 0)
			return;
		if (depth == 0) {
			yield.add(tree);
			return;
		}
		for (Tree<L> child : tree.getChildren()) {
			appendAtDepth(depth - 1, child, yield);
		}
	}

	/** Replaces this node's label. */
	public void setLabel(L label) {
		this.label = label;
	}

	/**
	 * Renders the tree in bracketed form, e.g. {@code (S (NP (DT the)))}.
	 * Leaves are printed bare and null labels are omitted.
	 */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		toStringBuilder(sb);
		return sb.toString();
	}

	/** Appends the bracketed rendering of this tree to {@code sb}. */
	public void toStringBuilder(StringBuilder sb) {
		if (!isLeaf())
			sb.append('(');
		if (getLabel() != null) {
			sb.append(getLabel());
		}
		if (!isLeaf()) {
			for (Tree<L> child : getChildren()) {
				sb.append(' ');
				child.toStringBuilder(sb);
			}
			sb.append(')');
		}
	}

	/**
	 * Same as {@link #toString()}, but escapes terminal labels.
	 * <ul>
	 * <li>{@code (} becomes {@code -LRB-}</li>
	 * <li>{@code )} becomes {@code -RRB-}</li>
	 * <li>{@code \} becomes {@code -BACKSLASH-}; backslashes do not occur in
	 * the PTB, so this is our own convention</li>
	 * </ul>
	 * Without escaping, a "(" terminal cannot be told apart from the tree's
	 * bracket structure, and an escaping backslash cannot be told apart from
	 * a literal one.
	 */
	public String toEscapedString() {
		StringBuilder sb = new StringBuilder();
		toStringBuilderEscaped(sb);
		return sb.toString();
	}

	/**
	 * Appends the escaped bracketed rendering to {@code sb}. Only leaf labels
	 * are escaped; internal labels are printed as-is.
	 */
	public void toStringBuilderEscaped(StringBuilder sb) {
		if (!isLeaf())
			sb.append('(');
		if (getLabel() != null) {
			if (isLeaf()) {
				// Literal (non-regex) replacements, applied in this order.
				String escapedLabel = getLabel().toString();
				escapedLabel = escapedLabel.replace("(", "-LRB-");
				escapedLabel = escapedLabel.replace(")", "-RRB-");
				escapedLabel = escapedLabel.replace("\\", "-BACKSLASH-");
				sb.append(escapedLabel);
			} else {
				sb.append(getLabel());
			}
		}
		if (!isLeaf()) {
			for (Tree<L> child : getChildren()) {
				sb.append(' ');
				child.toStringBuilderEscaped(sb);
			}
			sb.append(')');
		}
	}

	/**
	 * Creates a node with the given label and children. The children list is
	 * stored by reference, not copied.
	 */
	public Tree(L label, List<Tree<L>> children) {
		this.label = label;
		this.children = children;
	}

	/**
	 * Creates a leaf with the given label. Its children list is the immutable
	 * {@link Collections#emptyList()}.
	 */
	public Tree(L label) {
		this.label = label;
		this.children = Collections.emptyList();
	}

	/**
	 * Gets the set of all subtrees inside the tree, returning the tree rooted
	 * at each node. These are <i>not</i> copies; they all share structure.
	 * The tree is regarded as a subtree of itself. Equality is structural, so
	 * structurally identical subtrees collapse into a single set element.
	 * 
	 * @return the <code>Set</code> of all subtrees in the tree
	 */
	public Set<Tree<L>> subTrees() {
		return (Set<Tree<L>>) subTrees(new HashSet<Tree<L>>());
	}

	/**
	 * Gets the list of all subtrees inside the tree in pre-order, returning
	 * the tree rooted at each node. These are <i>not</i> copies; they all
	 * share structure. The tree is regarded as a subtree of itself.
	 * 
	 * @return the <code>List</code> of all subtrees in the tree
	 */
	public List<Tree<L>> subTreeList() {
		return (List<Tree<L>>) subTrees(new ArrayList<Tree<L>>());
	}

	/**
	 * Adds all subtrees inside this tree, including the tree itself, to the
	 * given <code>Collection</code> in pre-order.
	 * 
	 * @param n
	 *            a collection of nodes to which the subtrees will be added
	 * @return the collection parameter with the subtrees added
	 */
	public Collection<Tree<L>> subTrees(Collection<Tree<L>> n) {
		n.add(this);
		List<Tree<L>> kids = getChildren();
		for (Tree<L> kid : kids) {
			kid.subTrees(n);
		}
		return n;
	}

	/**
	 * Returns an iterator over the nodes of the tree, as required by the
	 * {@link Iterable} interface. It does a pre-order traversal, in which
	 * each node is returned before its children.
	 * 
	 * @return an iterator over the nodes of the tree
	 */
	public Iterator<Tree<L>> iterator() {
		return new TreeIterator();
	}

	/** Stack-based pre-order iterator; {@link #remove()} is unsupported. */
	private class TreeIterator implements Iterator<Tree<L>> {

		/** Pending nodes; the top of the stack is the last element. */
		private List<Tree<L>> treeStack;

		private TreeIterator() {
			treeStack = new ArrayList<Tree<L>>();
			treeStack.add(Tree.this);
		}

		public boolean hasNext() {
			return (!treeStack.isEmpty());
		}

		/**
		 * Returns the next node in pre-order.
		 * 
		 * @throws NoSuchElementException
		 *             if the traversal is exhausted
		 */
		public Tree<L> next() {
			if (treeStack.isEmpty())
				throw new NoSuchElementException();
			Tree<L> tr = treeStack.remove(treeStack.size() - 1);
			List<Tree<L>> kids = tr.getChildren();
			// Push children in reverse so the leftmost child is popped first.
			for (int i = kids.size() - 1; i >= 0; i--) {
				treeStack.add(kids.get(i));
			}
			return tr;
		}

		/**
		 * Not supported.
		 */
		public void remove() {
			throw new UnsupportedOperationException();
		}

	}

	/**
	 * Applies a transformation to all labels in the tree and returns the
	 * resulting tree. This tree is not modified.
	 * 
	 * @param <O>
	 *            output type of the transformation
	 * @param trans
	 *            the transformation to apply to each label
	 * @return the transformed tree
	 */
	public <O> Tree<O> transformNodes(MyMethod<L, O> trans) {
		ArrayList<Tree<O>> newChildren = new ArrayList<Tree<O>>(children.size());
		for (Tree<L> child : children) {
			newChildren.add(child.transformNodes(trans));
		}
		return new Tree<O>(trans.call(label), newChildren);
	}

	/**
	 * Applies a transformation to all nodes in the tree and returns the
	 * resulting tree. This differs from <code>transformNodes</code> in that
	 * {@code trans} receives the full node rather than just its label. Nodes
	 * are visited in pre-order: each node before its children.
	 * 
	 * @param <O>
	 *            output label type
	 * @param trans
	 *            maps each node of this tree to the label of the new node
	 * @return the transformed tree
	 */
	public <O> Tree<O> transformNodesUsingNode(MyMethod<Tree<L>, O> trans) {
		ArrayList<Tree<O>> newChildren = new ArrayList<Tree<O>>(children.size());
		O newLabel = trans.call(this);
		for (Tree<L> child : children) {
			newChildren.add(child.transformNodesUsingNode(trans));
		}
		return new Tree<O>(newLabel, newChildren);
	}

	/**
	 * Same as {@link #transformNodesUsingNode(MyMethod)}, but nodes are
	 * visited in post-order: every node's children, at all levels, are
	 * transformed before the node itself.
	 * 
	 * @param <O>
	 *            output label type
	 * @param trans
	 *            maps each node of this tree to the label of the new node
	 * @return the transformed tree
	 */
	public <O> Tree<O> transformNodesUsingNodePostOrder(
			MyMethod<Tree<L>, O> trans) {
		ArrayList<Tree<O>> newChildren = new ArrayList<Tree<O>>(children.size());
		for (Tree<L> child : children) {
			// Recurse with the post-order variant so the whole subtree is post-order.
			newChildren.add(child.transformNodesUsingNodePostOrder(trans));
		}
		O newLabel = trans.call(this);
		return new Tree<O>(newLabel, newChildren);
	}

	/**
	 * Structural hash over the label and all children, recursively. It is
	 * consistent with {@link #equals(Object)}. It costs time proportional to
	 * the tree size on every call, and it changes if the tree is mutated.
	 */
	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + ((label == null) ? 0 : label.hashCode());
		for (Tree<L> child : children) {
			result = prime * result + ((child == null) ? 0 : child.hashCode());
		}
		return result;
	}

	/**
	 * Structural equality: same runtime class, equal labels, and pairwise-equal
	 * children in the same order. Null labels and null children are handled,
	 * consistent with {@link #hashCode()}.
	 */
	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null || getClass() != obj.getClass())
			return false;
		@SuppressWarnings("unchecked")
		final Tree<L> other = (Tree<L>) obj;
		if (!nullSafeEquals(this.label, other.label))
			return false;
		final List<Tree<L>> myKids = this.getChildren();
		final List<Tree<L>> otherKids = other.getChildren();
		if (myKids.size() != otherKids.size())
			return false;
		for (int i = 0; i < myKids.size(); ++i) {
			if (!nullSafeEquals(myKids.get(i), otherKids.get(i)))
				return false;
		}
		return true;
	}

	/** Returns true if both are null or {@code a.equals(b)}. */
	private static boolean nullSafeEquals(Object a, Object b) {
		return (a == null) ? b == null : a.equals(b);
	}

	/**
	 * Orders trees first by label, then by number of children, then by
	 * children pairwise from left to right.
	 * <p>
	 * NOTE(review): labels are compared as {@code o.label.compareTo(this.label)},
	 * i.e. in reverse natural order, while child counts are ascending. This is
	 * preserved for compatibility with existing sort orders.
	 * 
	 * @throws IllegalArgumentException
	 *             if either label is not {@link Comparable} (including null)
	 */
	@SuppressWarnings("unchecked")
	public int compareTo(Tree<L> o) {
		if (!(o.getLabel() instanceof Comparable && getLabel() instanceof Comparable))
			throw new IllegalArgumentException("Tree labels are not comparable");
		int cmp = ((Comparable) o.getLabel()).compareTo(getLabel());
		if (cmp != 0)
			return cmp;
		final int mySize = this.getChildren().size();
		final int otherSize = o.getChildren().size();
		if (mySize != otherSize)
			return (mySize < otherSize) ? -1 : 1;
		for (int i = 0; i < mySize; ++i) {
			int cmp3 = getChildren().get(i).compareTo(o.getChildren().get(i));
			if (cmp3 != 0)
				return cmp3;
		}
		return 0;
	}

	/** Returns true if this node covers more than one leaf. */
	public boolean isPhrasal() {
		return getYield().size() > 1;
	}

	/**
	 * Returns the constituent of the smallest node whose span contains the
	 * half-open leaf span {@code [i, j)}.
	 * <p>
	 * NOTE: the returned constituent's end is <i>exclusive</i>, unlike
	 * {@link #getConstituents()}, whose ends are inclusive.
	 */
	public Constituent<L> getLeastCommonAncestorConstituent(int i, int j) {
		final List<L> yield = getYield();
		return getLeastCommonAncestorConstituentHelper(this, 0, yield.size(),
				i, j);
	}

	/**
	 * Returns the highest node whose span is exactly the half-open leaf span
	 * {@code [i, j)}, or {@code null} if no node covers exactly that span.
	 */
	public Tree<L> getTopTreeForSpan(int i, int j) {
		final List<L> yield = getYield();
		return getTopTreeForSpanHelper(this, 0, yield.size(), i, j);
	}

	/**
	 * Searches below {@code tree}, which spans {@code [start, end)}. It
	 * descends into the first child whose span contains {@code [i, j)}, and
	 * returns null when no child does.
	 */
	private static <L> Tree<L> getTopTreeForSpanHelper(Tree<L> tree, int start,
			int end, int i, int j) {
		assert i <= j;
		if (start == i && end == j) {
			return tree;
		}
		int currStart = start;
		for (Tree<L> child : tree.getChildren()) {
			final int currEnd = currStart + child.getYield().size();
			if (currStart <= i && currEnd >= j)
				return getTopTreeForSpanHelper(child, currStart, currEnd, i, j);
			currStart = currEnd;
		}
		return null;
	}

	/**
	 * Searches below {@code tree}, which spans {@code [start, end)}. It
	 * descends into the first child containing {@code [i, j)}. If no child
	 * contains the span, it returns {@code tree}'s own constituent, with an
	 * exclusive end.
	 */
	private static <L> Constituent<L> getLeastCommonAncestorConstituentHelper(
			Tree<L> tree, int start, int end, int i, int j) {
		if (start == i && end == j)
			return new Constituent<L>(tree.getLabel(), start, end);
		int currStart = start;
		for (Tree<L> child : tree.getChildren()) {
			final int currEnd = currStart + child.getYield().size();
			if (currStart <= i && currEnd >= j)
				return getLeastCommonAncestorConstituentHelper(child,
						currStart, currEnd, i, j);
			currStart = currEnd;
		}
		return new Constituent<L>(tree.getLabel(), start, end);
	}

	/**
	 * Returns true if some node below the root's single child has exactly one
	 * child. Preterminals, and the root's own unary, are not counted. Expects
	 * the root to have exactly one child; this is checked only by an
	 * {@code assert}.
	 */
	public boolean hasUnariesOtherThanRoot() {
		assert children.size() == 1;
		return hasUnariesHelper(children.get(0));
	}

	/** Returns true if {@code tree} or a descendant is a non-preterminal unary. */
	private boolean hasUnariesHelper(Tree<L> tree) {
		if (tree.isPreTerminal())
			return false;
		if (tree.getChildren().size() == 1)
			return true;
		for (Tree<L> child : tree.getChildren()) {
			if (hasUnariesHelper(child))
				return true;
		}
		return false;
	}

	/**
	 * Returns true if some unary node's only child is itself a unary,
	 * non-preterminal node, i.e. two stacked unaries. Unaries directly above
	 * a preterminal do not start a chain.
	 */
	public boolean hasUnaryChain() {
		return hasUnaryChainHelper(this, false);
	}

	/**
	 * @param unaryAbove
	 *            true if {@code tree}'s parent is a unary node
	 */
	private boolean hasUnaryChainHelper(Tree<L> tree, boolean unaryAbove) {
		boolean result = false;
		if (tree.getChildren().size() == 1) {
			if (unaryAbove)
				return true;
			else if (tree.getChildren().get(0).isPreTerminal())
				return false;
			else
				return hasUnaryChainHelper(tree.getChildren().get(0), true);
		} else {
			for (Tree<L> child : tree.getChildren()) {
				if (!child.isPreTerminal())
					result = result || hasUnaryChainHelper(child, false);
			}
		}
		return result;
	}

	/**
	 * Collapses unary chains in place. In each chain, the topmost unary node
	 * is kept, and the unary non-preterminal nodes directly below it are
	 * spliced out. The kept node is linked straight to the first branching
	 * node or preterminal beneath it. Requires mutable children lists on the
	 * kept nodes.
	 */
	public void removeUnaryChains() {
		removeUnaryChainHelper(this, null);
	}

	/**
	 * @param parent
	 *            the kept top of the current unary chain, or {@code null}
	 *            when {@code tree} is not inside a chain. When non-null it
	 *            always has exactly one child, so index 0 is the slot to
	 *            replace.
	 */
	private void removeUnaryChainHelper(Tree<L> tree, Tree<L> parent) {
		if (tree.isLeaf())
			return;
		if (tree.getChildren().size() == 1 && !tree.isPreTerminal()) {
			if (parent != null) {
				// Splice out this unary node: the chain top now points to our child.
				tree = tree.getChildren().get(0);
				parent.getChildren().set(0, tree);
				removeUnaryChainHelper(tree, parent);
			} else
				// This node starts a chain; keep it and collapse below it.
				removeUnaryChainHelper(tree.getChildren().get(0), tree);
		} else {
			for (Tree<L> child : tree.getChildren()) {
				if (!child.isPreTerminal())
					removeUnaryChainHelper(child, null);
			}
		}
	}

}