package org.lumongo.server.search;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.MultiPhraseQuery;
import org.apache.lucene.search.PhraseQuery;
import org.apache.lucene.search.Query;
import org.lumongo.server.config.IndexConfig;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Created by Matt Davis on 5/14/16.
 * @author mdavis
 * Copied mostly from org.apache.lucene.queryparser.classic.MultiFieldQueryParser
 */
public class LumongoMultiFieldQueryParser extends LumongoQueryParser {

	protected List<String> fields;
	protected Map<String, Float> boosts;
	private float dismaxTie = 0;
	private boolean dismax = false;

	public LumongoMultiFieldQueryParser(Analyzer analyzer, IndexConfig indexConfig) {
		super(analyzer, indexConfig);
	}

	public void enableDismax(float dismaxTie) {
		this.dismaxTie = dismaxTie;
		this.dismax = true;
	}

	public void disableDismax() {
		this.dismax = false;
	}

	public void setDefaultFields(Collection<String> fields) {
		setDefaultFields(fields, null);
	}

	public void setDefaultFields(Collection<String> fields, Map<String, Float> boosts) {
		this.field = null;
		this.fields = new ArrayList<>(fields);
		this.boosts = boosts;
	}

	@Override
	public void setDefaultField(String field) {
		super.setDefaultField(field);
		this.fields = null;
	}

	@Override
	protected Query getFieldQuery(String field, String queryText, int slop) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			for (String f : fields) {
				Query q = super.getFieldQuery(f, queryText, true);
				if (q != null) {
					//If the user passes a map of boosts
					if (boosts != null) {
						//Get the boost from the map and apply them
						Float boost = boosts.get(f);
						if (boost != null) {
							q = new BoostQuery(q, boost);
						}
					}
					q = applySlop(q, slop);
					clauses.add(q);
				}
			}
			if (clauses.size() == 0)  // happens for stopwords
				return null;
			return getMultiFieldQuery(clauses);
		}
		Query q = super.getFieldQuery(field, queryText, true);
		q = applySlop(q, slop);
		return q;
	}

	private Query applySlop(Query q, int slop) {
		if (q instanceof PhraseQuery) {
			PhraseQuery.Builder builder = new PhraseQuery.Builder();
			builder.setSlop(slop);
			PhraseQuery pq = (PhraseQuery) q;
			org.apache.lucene.index.Term[] terms = pq.getTerms();
			int[] positions = pq.getPositions();
			for (int i = 0; i < terms.length; ++i) {
				builder.add(terms[i], positions[i]);
			}
			q = builder.build();
		}
		else if (q instanceof MultiPhraseQuery) {
			MultiPhraseQuery mpq = (MultiPhraseQuery) q;

			if (slop != mpq.getSlop()) {
				q = new MultiPhraseQuery.Builder(mpq).setSlop(slop).build();
			}
		}
		return q;
	}

	@Override
	protected Query getFieldQuery(String field, String queryText, boolean quoted) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			Query[] fieldQueries = new Query[fields.size()];
			int maxTerms = 0;
			for (int i = 0; i < fields.size(); i++) {
				Query q = super.getFieldQuery(fields.get(i), queryText, quoted);
				if (q != null) {
					if (q instanceof BooleanQuery) {
						maxTerms = Math.max(maxTerms, ((BooleanQuery) q).clauses().size());
					}
					else {
						maxTerms = Math.max(1, maxTerms);
					}
					fieldQueries[i] = q;
				}
			}
			for (int termNum = 0; termNum < maxTerms; termNum++) {
				List<Query> termClauses = new ArrayList<>();
				for (int i = 0; i < fields.size(); i++) {
					if (fieldQueries[i] != null) {
						Query q = null;
						if (fieldQueries[i] instanceof BooleanQuery) {
							List<BooleanClause> nestedClauses = ((BooleanQuery) fieldQueries[i]).clauses();
							if (termNum < nestedClauses.size()) {
								q = nestedClauses.get(termNum).getQuery();
							}
						}
						else if (termNum == 0) { // e.g. TermQuery-s
							q = fieldQueries[i];
						}
						if (q != null) {
							if (boosts != null) {
								//Get the boost from the map and apply them
								Float boost = boosts.get(fields.get(i));
								if (boost != null) {
									q = new BoostQuery(q, boost);
								}
							}
							termClauses.add(q);
						}
					}
				}
				if (maxTerms > 1) {
					if (termClauses.size() > 0) {
						//mdavis - don't use super method because of min match
						BooleanQuery.Builder builder = new BooleanQuery.Builder();
						for (Query termClause : termClauses) {
							builder.add(termClause, BooleanClause.Occur.SHOULD);
						}
						clauses.add(builder.build());
					}
				}
				else {
					clauses.addAll(termClauses);
				}
			}
			if (clauses.size() == 0)  // happens for stopwords
				return null;
			return getMultiFieldQuery(clauses);
		}
		Query q = super.getFieldQuery(field, queryText, quoted);
		return q;
	}

	@Override
	protected Query getFuzzyQuery(String field, String termStr, float minSimilarity) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			for (String f : fields) {
				clauses.add(getFuzzyQuery(f, termStr, minSimilarity));
			}
			return getMultiFieldQuery(clauses);
		}
		return super.getFuzzyQuery(field, termStr, minSimilarity);
	}

	@Override
	protected Query getPrefixQuery(String field, String termStr) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			for (String f : fields) {
				clauses.add(getPrefixQuery(f, termStr));
			}
			return getMultiFieldQuery(clauses);
		}
		return super.getPrefixQuery(field, termStr);
	}

	@Override
	protected Query getWildcardQuery(String field, String termStr) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			for (String f : fields) {
				clauses.add(getWildcardQuery(f, termStr));
			}
			return getMultiFieldQuery(clauses);
		}
		return super.getWildcardQuery(field, termStr);
	}

	@Override
	protected Query getRangeQuery(String field, String part1, String part2, boolean startInclusive, boolean endInclusive) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			for (String f : fields) {
				clauses.add(getRangeQuery(f, part1, part2, startInclusive, endInclusive));
			}
			return getMultiFieldQuery(clauses);
		}
		return super.getRangeQuery(field, part1, part2, startInclusive, endInclusive);
	}

	@Override
	protected Query getRegexpQuery(String field, String termStr) throws ParseException {
		if (field == null) {
			List<Query> clauses = new ArrayList<>();
			for (String f : fields) {
				clauses.add(getRegexpQuery(f, termStr));
			}
			return getMultiFieldQuery(clauses);
		}
		return super.getRegexpQuery(field, termStr);
	}

	/** Creates a multifield query */
	// TODO: investigate more general approach by default, e.g. DisjunctionMaxQuery?
	protected Query getMultiFieldQuery(List<Query> queries) throws ParseException {
		if (queries.isEmpty()) {
			return null; // all clause words were filtered away by the analyzer.
		}

		if (dismax) {
			return new DisjunctionMaxQuery(queries, dismaxTie);
		}
		else {
			//mdavis - don't use super method because of min match
			BooleanQuery.Builder query = new BooleanQuery.Builder();
			for (Query sub : queries) {
				query.add(sub, BooleanClause.Occur.SHOULD);
			}

			return query.build();
		}
	}

}