/*
 * Seldon -- open source prediction engine
 * =======================================
 *
 * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
 *
 * ********************************************************************************************
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * ********************************************************************************************
 */

package io.seldon.sv;

import io.seldon.recommendation.RecommendationUtils;
import io.seldon.semvec.QueryTransform;
import io.seldon.semvec.SemVectorResult;
import io.seldon.semvec.VectorStorePredictor;
import io.seldon.semvec.VectorStoreRecommender;
import io.seldon.util.CollectionTools;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Timer;

import org.apache.log4j.Logger;

import pitt.search.semanticvectors.FlagConfig;
import pitt.search.semanticvectors.LuceneUtils;
import pitt.search.semanticvectors.SearchResult;
import pitt.search.semanticvectors.VectorSearcher;
import pitt.search.semanticvectors.VectorStore;
import pitt.search.semanticvectors.VectorStoreRAM;
import pitt.search.semanticvectors.vectors.Vector;
import pitt.search.semanticvectors.vectors.ZeroVectorException;

/**
 * Wrapper around a pair of SemanticVectors stores: a term store and a document store.
 * <p>
 * Offers nearest-neighbour search, filtered recommendation, prediction and similarity-based
 * sorting of items. Callers convert between their own id types and the string names held in the
 * vector stores through a {@link QueryTransform}.
 */
public class SemanticVectorsStore {
	
	private static final Logger logger = Logger.getLogger( SemanticVectorsStore.class.getName() );
	
	private VectorStore termVecReader;
	private VectorStore docVecReader;
	// Never assigned in this class, so it is always null when handed to the searchers below.
	private LuceneUtils luceneUtils = null;
	
	String baseDir;
	String key;
	String termFilename;
	String docFilename;
	Timer reloadTimer; // The timer to be used to reload term/doc databases
	boolean useRamStores = true;
	// Suffixes handed to VectorStorePredictor.getPrediction in predict().
	String[] predictionSuffixes = new String[] {"1","2","3","4","5" };
	FlagConfig flagConfig = FlagConfig.getFlagConfig(null);
	
	/**
	 * Create a store over already-loaded in-memory term and document vector stores.
	 * @param termStore vectors keyed by term name
	 * @param docStore vectors keyed by document name
	 */
	public SemanticVectorsStore(VectorStoreRAM termStore,VectorStoreRAM docStore)
	{
		this.termVecReader = termStore;
		this.docVecReader = docStore;
		this.useRamStores = true;
	}
	
	/**
	 * Find the terms nearest to a term query (term store searched with a term-store query).
	 * @param termQuery the query item, converted with {@code termTransform}
	 * @param docResult output list; results are appended in the order returned by the searcher
	 * @param termTransform converts between T and SV term names
	 * @param numResults maximum number of results requested from the searcher
	 */
	public <T extends Comparable<T>>void searchTermsUsingTermQuery(T termQuery,ArrayList<SemVectorResult<T>> docResult,QueryTransform<T> termTransform,int numResults)
	{
		String query = termTransform.toSV(termQuery);
		addResults(search(query,termVecReader,termVecReader,numResults), docResult, termTransform);
	}
	
	/**
	 * Find the documents nearest to a term query (term-store query against the document store).
	 * @param termQuery the query item, converted with {@code termTransform}
	 * @param docResult output list; results are appended
	 * @param docTransform converts SV document names back to T
	 * @param termTransform converts the query to an SV term string
	 * @param numResults maximum number of results requested from the searcher
	 */
	public <T extends Comparable<T>,L extends Comparable<L>>void searchDocsUsingTermQuery(L termQuery,ArrayList<SemVectorResult<T>> docResult,QueryTransform<T> docTransform,QueryTransform<L> termTransform,int numResults)
	{
		String query = termTransform.toSV(termQuery);
		addResults(search(query,termVecReader,docVecReader,numResults), docResult, docTransform);
	}
	
	/**
	 * Recommend documents for a term query, applying exclusion/inclusion filters.
	 * @param termQuery the query item, converted with {@code termTransform}
	 * @param docResult output list; results are appended
	 * @param docTransform converts between T and SV document names
	 * @param termTransform converts the query to an SV term string
	 * @param numResults maximum number of results requested from the recommender
	 * @param exclusions documents to exclude (null means none)
	 * @param inclusions documents to restrict to (null means none given)
	 * @param minDoc converted with {@code docTransform} and passed to the recommender as-is
	 */
	public <T extends Comparable<T>,L extends Comparable<L>>void recommendDocsUsingTermQuery(L termQuery,ArrayList<SemVectorResult<T>> docResult,QueryTransform<T> docTransform,QueryTransform<L> termTransform,int numResults,Set<T> exclusions,Set<T> inclusions,T minDoc)
	{
		String query = termTransform.toSV(termQuery);
		Set<String> docExclusions = toSvNames(exclusions, docTransform);
		Set<String> docInclusions = toSvNames(inclusions, docTransform);
		LinkedList<SearchResult> results = recommend(query,termVecReader,docVecReader,numResults,docExclusions,docInclusions,docTransform.toSV(minDoc));
		addResults(results, docResult, docTransform);
	}
	
	/**
	 * Predict a value for a (user, item) pair using the vector stores.
	 * @param user user id, used as a term-store query via its decimal string
	 * @param item item id, passed to the predictor via its decimal string
	 * @return the parsed prediction, or null when the predictor returns no prediction
	 */
	public Double predict(long user,long item)
	{
		VectorStorePredictor predictor = new VectorStorePredictor(""+user,termVecReader,docVecReader,null);
		String prediction = predictor.getPrediction(""+item, predictionSuffixes);
		if (prediction != null)
			return Double.parseDouble(prediction);
		else
			return null;
	}

	/**
	 * Recommend documents similar to a single document (document store queried with itself).
	 * @param docQuery the query document
	 * @param docResult output list; results are appended
	 * @param docTransform converts between T and SV document names
	 * @param numResults maximum number of results requested from the recommender
	 * @param exclusions documents to exclude; null is treated as "no exclusions"
	 * @param minDoc converted with {@code docTransform} and passed to the recommender as-is
	 */
	public <T extends Comparable<T>>void recommendDocsUsingDocQuery(T docQuery,ArrayList<SemVectorResult<T>> docResult,QueryTransform<T> docTransform,int numResults,Set<T> exclusions,T minDoc)
	{
		String docName = docTransform.toSV(docQuery);
		// Null-safe conversion: previously a null exclusion set caused a NullPointerException here.
		Set<String> docExclusions = toSvNames(exclusions, docTransform);
		LinkedList<SearchResult> results = recommend(docName,docVecReader,docVecReader,numResults,docExclusions,new HashSet<String>(),docTransform.toSV(minDoc));
		addResults(results, docResult, docTransform);
	}
	
	/**
	 * Find the documents nearest to a document (document store queried with itself).
	 * @param docQuery the query document
	 * @param docResult output list; results are appended
	 * @param docTransform converts between T and SV document names
	 * @param numResults maximum number of results requested from the searcher
	 */
	public <T extends Comparable<T>>void searchDocsUsingDocQuery(T docQuery,ArrayList<SemVectorResult<T>> docResult,QueryTransform<T> docTransform,int numResults)
	{
		String docName = docTransform.toSV(docQuery);
		addResults(search(docName,docVecReader,docVecReader,numResults), docResult, docTransform);
	}
	
	/**
	 * Run a cosine nearest-neighbour search.
	 * @return the searcher's results, or an empty list if a ZeroVectorException is raised
	 */
	private LinkedList<SearchResult> search(String query,VectorStore queryStore,VectorStore searchStore,int numResults)
	{
		try
		{
			VectorSearcher vecSearcher =
					new VectorSearcher.VectorSearcherCosine(queryStore,
															searchStore,
															luceneUtils,
															flagConfig,
															splitQuery(query));
			return vecSearcher.getNearestNeighbors(numResults);
		}
		catch (ZeroVectorException e)
		{
			// Presumably the query vector could not be built (no usable query terms): no matches.
			return new LinkedList<>();
		}
	}
	
	/**
	 * Run a cosine nearest-neighbour recommendation with exclusion/inclusion filters.
	 * @return the recommender's results, or an empty list if a ZeroVectorException is raised
	 */
	private LinkedList<SearchResult> recommend(String query,VectorStore queryStore,VectorStore searchStore,int numResults,Set<String> exclusions,Set<String> inclusions,String minDoc)
	{
		try
		{
			VectorStoreRecommender recommender =
					new VectorStoreRecommender.VectorStoreRecommenderCosine(queryStore,
															searchStore,
															luceneUtils,
															splitQuery(query),
															exclusions,
															inclusions,
															minDoc);
			return recommender.getNearestNeighbors(numResults);
		}
		catch (ZeroVectorException e)
		{
			return new LinkedList<>();
		}
	}
	
	/**
	 * Split a query string into whitespace-separated terms, ignoring leading/trailing whitespace
	 * (which would otherwise produce an empty first term).
	 */
	private static String[] splitQuery(String query)
	{
		return query.trim().split("\\s+");
	}
	
	/**
	 * Convert a set of ids to their SV names.
	 * @return a new mutable set; empty when {@code items} is null
	 */
	private static <T extends Comparable<T>> Set<String> toSvNames(Set<T> items,QueryTransform<T> transform)
	{
		Set<String> names = new HashSet<>();
		if (items != null)
			for (T item : items)
				names.add(transform.toSV(item));
		return names;
	}
	
	/**
	 * Append search results to {@code out}, converting each result's SV name back to T.
	 */
	private static <T extends Comparable<T>> void addResults(List<SearchResult> results,List<SemVectorResult<T>> out,QueryTransform<T> transform)
	{
		for (SearchResult r : results)
		{
			String svName = r.getObjectVector().getObject().toString();
			out.add(new SemVectorResult<>(transform.fromSV(svName),r.getScore()));
		}
	}
	
	/**
	 * Similarity between two documents as given by {@code Vector.measureOverlap}.
	 * @return the overlap, or 0 if either document has no vector in the document store
	 */
	public <T extends Comparable<T>> double getSimilarity(T a,T b,QueryTransform<T> docTransform)
	{
		Vector vectorA = docVecReader.getVector(docTransform.toSV(a));
		Vector vectorB = docVecReader.getVector(docTransform.toSV(b));
		if (vectorA != null && vectorB != null)
			return vectorA.measureOverlap(vectorB);
		else
			return 0D;
	}
	
	/**
	 * Sort items by similarity to a list of recent items, with no exclusions.
	 * @see #sortDocsUsingDocQuery(List, List, QueryTransform, Set)
	 */
	public <T extends Comparable<T>> List<T> sortDocsUsingDocQuery(List<T> recentItems,List<T> sortItems,QueryTransform<T> docTransform)
	{
		return sortDocsUsingDocQuery(recentItems, sortItems, docTransform, new HashSet<T>());
	}
	
	/**
	 * Sort a set of items based on similarity with a list of items.
	 * <p>
	 * Items with a vector are ordered by their summed overlap with the recent items.
	 * After them come items that are in {@code recentItems} or {@code exclusions}, then items with
	 * no (or a zero) vector.
	 * @param <T> id type
	 * @param recentItems items to compare against
	 * @param sortItems items to sort
	 * @param docTransform converts between T and SV document names
	 * @param exclusions items that are not scored but appended after the scored items; null means none
	 * @return the sorted items, or an empty list if no sort item or no recent item has a vector
	 */
	public <T extends Comparable<T>> List<T> sortDocsUsingDocQuery(List<T> recentItems,List<T> sortItems,QueryTransform<T> docTransform,Set<T> exclusions)
	{
		//various hardwired algorithms - not yet exposed in settings as this is early stage testing and we
		// may only use the best one
		boolean useRank = false;
		boolean bestScore = false;
		boolean useThreshold = false;
		double threshold = 0.999;
		List<T> result = new ArrayList<>();
		Map<Vector,T> sortVectors = new HashMap<>();
		Map<Vector,Double> scores = new HashMap<>();
		boolean comparisonsMade = false;
		boolean foundItemsToSort = false;
		List<T> alreadySeen = new ArrayList<>();
		List<T> notFound = new ArrayList<>();
		for(T item : sortItems)
		{
			boolean excluded = exclusions != null && exclusions.contains(item);
			if (!recentItems.contains(item) && !excluded)
			{
				Vector v = docVecReader.getVector(docTransform.toSV(item));
				if (v != null && !v.isZeroVector())
				{
					foundItemsToSort = true;
					sortVectors.put(v,item);
					scores.put(v, 0D);
				}
				else
				{
					notFound.add(item);
					logger.warn("Can't find vector for sort item "+item);
				}
			}
			else
			{
				if (logger.isDebugEnabled())
					logger.debug("Not sorting already seen article "+item);
				alreadySeen.add(item);
			}
		}
		if (!foundItemsToSort)
		{
			logger.debug("No sort items so returning empty list");
			return new ArrayList<>();
		}
		for(T recent : recentItems)
		{
			if (logger.isDebugEnabled())
				logger.debug("Recent item " + recent);
			Vector vectorRecent = docVecReader.getVector(docTransform.toSV(recent));
			if (vectorRecent != null && !vectorRecent.isZeroVector())
			{
				comparisonsMade = true;
				if (useRank)
				{
					// Accumulate each item's rank position for this recent item.
					Map<Vector,Double> scoresLocal = new HashMap<>();
					for(Map.Entry<Vector, T> e : sortVectors.entrySet())
						scoresLocal.put(e.getKey(), vectorRecent.measureOverlap(e.getKey()));
					List<Vector> orderedLocal = CollectionTools.sortMapAndLimitToList(scoresLocal, scoresLocal.size());
					double count = 1;
					for(Vector vOrdered : orderedLocal)
					{
						scores.put(vOrdered, scores.get(vOrdered)+count);
						count++;
					}
				}
				else
				{
					for(Map.Entry<Vector, T> e : sortVectors.entrySet())
					{
						double overlap = vectorRecent.measureOverlap(e.getKey());
						double current = scores.get(e.getKey());
						if (!Double.isNaN(overlap))
						{
							if (logger.isDebugEnabled())
								logger.debug("Overlap with "+e.getValue()+" is "+overlap);
							if (bestScore) // just store best score
							{
								if (overlap > current)
									scores.put(e.getKey(), overlap);
							}
							else
							{
								if (useThreshold) // only add scores for high threshold matches
								{
									if (current < threshold && overlap > current)
										scores.put(e.getKey(),overlap);
									else if (current > threshold && overlap > threshold)
										scores.put(e.getKey(),overlap+current);
								}
								else // add all scores together good or bad
									scores.put(e.getKey(),overlap+current);
							}
						}
					}
				}
			}
			else
				logger.warn("Can't get vector for recent item "+recent);
		}
		if (comparisonsMade)
		{
			List<Vector> ordered;
			// Rank sums are "lower is better", so they are sorted with the extra false argument.
			if (useRank)
				ordered = CollectionTools.sortMapAndLimitToList(scores, scores.size(),false);
			else
				ordered = CollectionTools.sortMapAndLimitToList(scores, scores.size());
			for(Vector vOrdered : ordered)
			{
				if (logger.isDebugEnabled())
					logger.debug("Item " + sortVectors.get(vOrdered) + " has score " +  scores.get(vOrdered));
				result.add(sortVectors.get(vOrdered));
			}
			for(T seenItem : alreadySeen)
			{
				if (logger.isDebugEnabled())
					logger.debug("Adding already seen item "+seenItem+" to end of list");
				result.add(seenItem);
			}
			for(T notFoundItem : notFound)
			{
				if (logger.isDebugEnabled())
					logger.debug("Adding not found item "+notFoundItem+" to end of list");
				result.add(notFoundItem);
			}
			return result;
		}
		else
		{
			logger.debug("No comparisons made so returning empty list");
			return new ArrayList<>();
		}
	}
	
	/**
	 * Score candidate items by their summed overlap with a list of recent items.
	 * @param recentItems items to compare against; candidates that are also recent are skipped
	 * @param sortItems candidate items
	 * @param docTransform converts between T and SV document names
	 * @param numRecommendations passed to RecommendationUtils.rescaleScoresToOne
	 * @param ignorePerfectMatches when true, overlaps exactly equal to 1.0 are not added
	 * @return candidate scores after RecommendationUtils.rescaleScoresToOne, or an empty map if no
	 *         candidate or no recent item has a vector
	 */
	public <T extends Comparable<T>> Map<T,Double> recommendDocsUsingDocQuery(List<T> recentItems,Set<T> sortItems,QueryTransform<T> docTransform,
																			  int numRecommendations,boolean ignorePerfectMatches)
	{
		Map<Vector,T> sortVectors = new HashMap<>();
		Map<Vector,Double> scores = new HashMap<>();
		boolean comparisonsMade = false;
		for(T item : sortItems)
		{
			if (!recentItems.contains(item))
			{
				Vector v = docVecReader.getVector(docTransform.toSV(item));
				// Candidates without a (non-zero) vector cannot be scored and are dropped.
				if (v != null && !v.isZeroVector())
				{
					sortVectors.put(v,item);
					scores.put(v, 0D);
				}
			}
			else if (logger.isDebugEnabled())
				logger.debug("Not sorting already seen article "+item);
		}
		if (sortVectors.isEmpty())
		{
			logger.debug("No sort items so returning empty list");
			return new HashMap<>();
		}
		for(T recent : recentItems)
		{
			if (logger.isDebugEnabled())
				logger.debug("Recent item " + recent);
			Vector vectorRecent = docVecReader.getVector(docTransform.toSV(recent));
			if (vectorRecent != null && !vectorRecent.isZeroVector())
			{
				comparisonsMade = true;
				for(Map.Entry<Vector, T> e : sortVectors.entrySet())
				{
					double overlap = vectorRecent.measureOverlap(e.getKey());
					if (!Double.isNaN(overlap))
					{
						if (ignorePerfectMatches && overlap == 1.0)
							logger.info("Ignoring perfect match between "+recent+" and "+e.getValue()+" overlap "+overlap);
						else
							scores.put(e.getKey(),overlap+scores.get(e.getKey()));
					}
				}
			}
			else
				logger.warn("Can't get vector for recent item "+recent);
		}
		if (comparisonsMade)
		{
			Map<T,Double> scoresRes = new HashMap<>();
			for(Map.Entry<Vector, Double> e : scores.entrySet())
				scoresRes.put(sortVectors.get(e.getKey()), e.getValue());
			return RecommendationUtils.rescaleScoresToOne(scoresRes, numRecommendations);
		}
		else
		{
			logger.debug("No comparisons made so returning empty list");
			return new HashMap<>();
		}
	}
	
	/**
	 * Recommend a set of documents based on some recent viewed documents.
	 * <p>
	 * For every recent item, up to {@code numResults*10} neighbours are requested. Their scores are
	 * summed per recommended item and passed to RecommendationUtils.rescaleScoresToOne.
	 * @param <T> id type
	 * @param recentItems : items to use as comparison
	 * @param docTransform : transform an internal id to SV doc
	 * @param numResults : max number of recommendations to return
	 * @param exclusions : items to exclude from returned recommendations (null means none)
	 * @param minDoc : passed through to the per-document recommender
	 * @param ignorePerfectMatches : when true, neighbours scoring exactly 1.0 are skipped
	 * @return recommended items and their rescaled scores
	 */
	//General recommendations
	public <T extends Comparable<T>> Map<T,Double> recommendDocsUsingDocQuery(List<T> recentItems,QueryTransform<T> docTransform,int numResults,Set<T> exclusions,T minDoc,boolean ignorePerfectMatches)
	{
		Map<T,Double> scores = new HashMap<>();
		
		for(T recent : recentItems)
		{
			ArrayList<SemVectorResult<T>> docResult = new ArrayList<>();
			recommendDocsUsingDocQuery(recent,docResult,docTransform,numResults*10,exclusions,minDoc);
			for(SemVectorResult<T> r : docResult)
			{
				if (ignorePerfectMatches && r.score == 1.0)
				{
					logger.info("Ignoring perfect match between "+recent+" and "+r.result+" overlap "+r.score);
					continue;
				}
				Double score = scores.get(r.result);
				scores.put(r.result, score == null ? r.score : score + r.score);
			}
		}

		return RecommendationUtils.rescaleScoresToOne(scores, numResults);
	}
	
	/**
	 * Find similar users by querying the docstore using a query from the terms passed in
	 * @param <T>
	 * @param terms : query terms looked up in the term store
	 * @param lUtils : lucene utils handed to the searcher (may be null)
	 * @param numResults : max number of results to return
	 * @param docResult : the result list of return ids T (results are appended)
	 * @param docTransform : the transform from document to return id type T
	 */
	public <T extends Comparable<T>> void findSimilarUsersFromTerms(String[] terms,LuceneUtils lUtils,int numResults,ArrayList<SemVectorResult<T>> docResult,QueryTransform<T> docTransform)
	{
		List<SearchResult> results;
		try
		{
			// Use the caller-supplied LuceneUtils; it was previously ignored in favour of the
			// never-assigned luceneUtils field.
			VectorSearcher vecSearcher =
					new VectorSearcher.VectorSearcherCosine(termVecReader,
															docVecReader,
															lUtils,
															flagConfig,
															terms);
			results = vecSearcher.getNearestNeighbors(numResults);
		}
		catch (ZeroVectorException e)
		{
			results = new LinkedList<>();
		}
		addResults(results, docResult, docTransform);
	}
	
	public static void main(String[] args)
	{
		// Ad-hoc check of the lexicographic ordering of SV document names.
		String a = "docs/0000/1234";
		String b = "docs/0000/1235";
		if (a.compareTo(b) < 0)
			System.out.println("less than");
		else
			System.out.println("greater than");
	}

	/** @return the term vector store */
	public VectorStore getTermVecReader() {
		return termVecReader;
	}

	/** @return the document vector store */
	public VectorStore getDocVecReader() {
		return docVecReader;
	}
	
}