package processing;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Stopwatch;
import com.google.common.primitives.Ints;

import common.DoubleMapComparator;
import common.Bookmark;
import common.MemoryThread;
import common.PerformanceMeasurement;
import common.Utilities;
import file.PredictionFileWriter;
import file.BookmarkReader;

/**
 * Collaborative-filtering tag recommender.
 * <p>
 * Tags are scored by summing similarity weights over the most similar users
 * (similarity from {@code Utilities.getCosineSim} on the per-user maps) and/or
 * the most similar resources (similarity from {@code Utilities.getJaccardSim}
 * on the per-resource maps). Which of the two sources is used is selected by
 * the {@code userBased} / {@code resBased} flags given to the constructor.
 * <p>
 * The static entry points ({@link #predictTags} and {@link #predictSample})
 * read a bookmark file, run the recommender on the last {@code sampleSize}
 * bookmarks and write the predictions and timing/memory metrics to disk.
 */
public class CFTagRecommender {
	
	/**
	 * Maximum number of neighbors (users or resources) that contribute tags.
	 * Mutable on purpose: {@link #predictTags} overwrites it before a run.
	 */
	public static int MAX_NEIGHBORS = 20;
	/** Number of tags kept when {@link #getRankedTagList} is asked to sort. */
	private final static int MAX_RECOMMENDED_TAGS = 10;
	/** Term-frequency saturation constant for the neighbor map in {@link #getBM25Value}. */
	private final static double K1 = 1.2;
	/** Term-frequency saturation constant for the target map in {@link #getBM25Value}. */
	private final static double K3 = 1.2;
	/** Length-normalization constant in {@link #getBM25Value}. */
	private final static double B = 0.8;
	
	private BookmarkReader reader;
	private boolean userBased;
	private boolean resBased;
	/** Stored as {@code beta / 10.0}; the weighting that used it is currently disabled. */
	private double beta;
	private List<Bookmark> trainList;
	/** Per-user maps, indexed by user ID; only populated when {@code userBased} is set. */
	private List<Map<Integer, Integer>> userMaps;
	/** Per-resource maps, indexed by resource ID; only populated when {@code resBased} is set. */
	private List<Map<Integer, Integer>> resMaps;
	
	/**
	 * Builds the recommender from the first {@code trainSize} bookmarks of the reader.
	 *
	 * @param reader    source of the bookmarks
	 * @param trainSize number of leading bookmarks used as training data
	 * @param userBased whether similar users contribute tags
	 * @param resBased  whether similar resources contribute tags
	 * @param beta      mixing weight in tenths (stored as {@code beta / 10.0})
	 */
	public CFTagRecommender(BookmarkReader reader, int trainSize, boolean userBased, boolean resBased, int beta) {
		this.reader = reader;
		this.userBased = userBased;
		this.resBased = resBased;
		this.beta = (double)beta / 10.0;
		this.trainList = this.reader.getBookmarks().subList(0, trainSize);
		if (this.userBased) {
			this.userMaps = Utilities.getUserMaps(this.trainList);
		}
		if (this.resBased) {
			this.resMaps = Utilities.getResMaps(this.trainList);
		}
	}
		
	/**
	 * Computes tag scores for a user/resource pair.
	 * <p>
	 * Each of the enabled sources (similar users, similar resources) contributes
	 * the tags of at most {@link #MAX_NEIGHBORS} neighbors; a tag's score is the
	 * sum of the similarity values of the neighbors that carry it.
	 *
	 * @param userID  ID of the target user
	 * @param resID   ID of the target resource, or -1 to ignore the resource
	 * @param sorting if {@code true}, the result is ordered by
	 *                {@code DoubleMapComparator} and cut to the first 10 tags
	 * @return map from tag ID to accumulated score
	 */
	public Map<Integer, Double> getRankedTagList(int userID, int resID, boolean sorting) {
		Map<Integer, Double> resultMap = new LinkedHashMap<Integer, Double>();
		
		if (this.userBased) {
			// Own counter per source, so the user-based part cannot use up the
			// neighbor budget of the resource-based part in hybrid mode.
			int usedNeighbors = 0;
			for (Map.Entry<Integer, Double> entry : getNeighbors(userID, resID).entrySet()) {
				if (usedNeighbors++ >= MAX_NEIGHBORS) {
					break;
				}
				List<Integer> tags = Bookmark.getBookmark(this.trainList, entry.getKey(), resID).getTags();
				addWeightedTags(resultMap, tags, entry.getValue());
			}
		}
		if (this.resBased) {
			int usedResources = 0;
			for (Map.Entry<Integer, Double> entry : getSimResources(userID, resID).entrySet()) {
				if (usedResources++ >= MAX_NEIGHBORS) {
					break;
				}
				List<Integer> tags = Bookmark.getResData(this.trainList, userID, entry.getKey()).getTags();
				addWeightedTags(resultMap, tags, entry.getValue());
			}
		}
		
		if (!sorting) {
			return resultMap;
		}
		
		Map<Integer, Double> sortedResultMap = new TreeMap<Integer, Double>(new DoubleMapComparator(resultMap));
		sortedResultMap.putAll(resultMap);
		Map<Integer, Double> returnMap = new LinkedHashMap<Integer, Double>(MAX_RECOMMENDED_TAGS);
		int index = 0;
		for (Map.Entry<Integer, Double> entry : sortedResultMap.entrySet()) {
			if (index++ >= MAX_RECOMMENDED_TAGS) {
				break;
			}
			returnMap.put(entry.getKey(), entry.getValue());
		}
		return returnMap;
	}
	
	/** Adds {@code weight} to the score of every tag in {@code tags}. */
	private static void addWeightedTags(Map<Integer, Double> scores, List<Integer> tags, double weight) {
		for (int tag : tags) {
			Double current = scores.get(tag);
			scores.put(tag, (current != null ? current + weight : weight));
		}
	}
	
	/**
	 * Collects candidate neighbor users and weights them by similarity to the target user.
	 * <p>
	 * Candidates are all other users that have a training bookmark on {@code resID}
	 * (or all other users if {@code resID} is -1). If there are none, every user of
	 * the training set is used. When {@code userID} is not a valid index into the
	 * user maps, the candidates are returned unsorted with weight 0.
	 *
	 * @return map from user ID to similarity, ordered by {@code DoubleMapComparator}
	 *         when similarities could be computed
	 */
	private Map<Integer, Double> getNeighbors(int userID, int resID) {
		Map<Integer, Double> neighbors = new LinkedHashMap<Integer, Double>();
		// get all users that have tagged the resource
		for (Bookmark data : this.trainList) {
			boolean otherUser = data.getUserID() != userID;
			if (otherUser && (resID == -1 || data.getResourceID() == resID)) {
				neighbors.put(data.getUserID(), 0.0);
			}
		}
		// if list is empty, use all users
		if (neighbors.isEmpty()) {
			for (Bookmark data : this.trainList) {
				neighbors.put(data.getUserID(), 0.0);
			}
		}
		
		// Lower bound check: a negative ID must not be used as a list index.
		if (userID < 0 || userID >= this.userMaps.size()) {
			return neighbors;
		}
		
		Map<Integer, Integer> targetMap = this.userMaps.get(userID);
		for (Map.Entry<Integer, Double> entry : neighbors.entrySet()) {
			entry.setValue(Utilities.getCosineSim(targetMap, this.userMaps.get(entry.getKey())));
		}
		
		// return the sorted neighbors
		Map<Integer, Double> sortedNeighbors = new TreeMap<Integer, Double>(new DoubleMapComparator(neighbors));
		sortedNeighbors.putAll(neighbors);
		return sortedNeighbors;
	}
	
	/**
	 * Collects candidate resources and weights them by similarity to the target resource.
	 * <p>
	 * Candidates are all other resources the user has bookmarked in the training set
	 * (or all other resources if {@code userID} is -1). If there are none, every
	 * resource of the training set is used. When {@code resID} is not a valid index
	 * into the resource maps (e.g. -1 when the resource is ignored), the candidates
	 * are returned unsorted with weight 0.
	 *
	 * @return map from resource ID to similarity, ordered by {@code DoubleMapComparator}
	 *         when similarities could be computed
	 */
	private Map<Integer, Double> getSimResources(int userID, int resID) {
		Map<Integer, Double> resources = new LinkedHashMap<Integer, Double>();
		// get all resources that have been tagged by the user
		for (Bookmark data : this.trainList) {
			boolean otherResource = data.getResourceID() != resID;
			if (otherResource && (userID == -1 || data.getUserID() == userID)) {
				resources.put(data.getResourceID(), 0.0);
			}
		}
		// if list is empty, use all resources
		if (resources.isEmpty()) {
			for (Bookmark data : this.trainList) {
				resources.put(data.getResourceID(), 0.0);
			}
		}
		
		// Lower bound check: resID is -1 when the caller ignores the resource,
		// and a negative ID must not be used as a list index.
		if (resID < 0 || resID >= this.resMaps.size()) {
			return resources;
		}
		
		Map<Integer, Integer> targetMap = this.resMaps.get(resID);
		for (Map.Entry<Integer, Double> entry : resources.entrySet()) {
			entry.setValue(Utilities.getJaccardSim(targetMap, this.resMaps.get(entry.getKey())));
		}
		
		// return the sorted resources
		Map<Integer, Double> sortedResources = new TreeMap<Integer, Double>(new DoubleMapComparator(resources));
		sortedResources.putAll(resources);
		return sortedResources;
	}
	
	/**
	 * Sums the BM25 term scores of all keys of {@code targetMap} against {@code nMap}.
	 * Prints the intermediate values to standard output whenever a term score is NaN.
	 *
	 * @param neighborMaps collection used to compute the IDF of each key
	 * @param lAverage     average map count, see {@link #getLAverage}
	 * @param targetMap    query-side frequencies
	 * @param nMap         document-side frequencies
	 * @return the summed score
	 */
	public static double getBM25Value(List<Map<Integer, Integer>> neighborMaps, double lAverage, Map<Integer, Integer> targetMap, Map<Integer, Integer> nMap) {
		double bm25Sum = 0.0;
		// Does not depend on the current term, so it is computed once.
		double ld = Utilities.getMapCount(nMap);
		for (Map.Entry<Integer, Integer> targetVal : targetMap.entrySet()) {
			double idf = getIDF(targetVal.getKey(), neighborMaps);
			Integer neighborCount = nMap.get(targetVal.getKey());
			double tftd = (neighborCount != null ? neighborCount.doubleValue() : 0.0);
			double tftq = (double)targetVal.getValue();
			double bm25Val = (idf * (K1 + 1) * tftd * (K3 + 1) * tftq) /
					((K1 * ((1 - B) + B * (ld / lAverage)) + tftd) * (K3 + tftq));
			
			if (Double.isNaN(bm25Val)) {
				System.out.println("idf: " + idf + ", tftd: " + tftd + ", tftq: " + tftq + ", ld: " + ld + ", lAverage: " + lAverage);
			}
			
			bm25Sum += bm25Val;
		}
		return bm25Sum;
	}
	
	/**
	 * Counts the maps that contain {@code tagID} as a key.
	 */
	public static int getTagFrequency(int tagID, List<Map<Integer, Integer>> neighborMaps) {
		int count = 0;
		for (Map<Integer, Integer> map : neighborMaps) {
			if (map.containsKey(tagID)) {
				count++;
			}
		}
		
		return count;
	}
	
	/**
	 * Converts every map of {@code neighborMaps} into its TF-IDF weighted form,
	 * using the same list as the IDF collection.
	 */
	public static List<Map<Integer, Double>> getTFIDFMaps(List<Map<Integer, Integer>> neighborMaps) {
		List<Map<Integer, Double>> targetMaps = new ArrayList<Map<Integer, Double>>();
		for (Map<Integer, Integer> srcMap : neighborMaps) {
			targetMaps.add(getTFIDFMap(srcMap, neighborMaps));
		}
		return targetMaps;
	}
	
	/**
	 * Multiplies every frequency of {@code srcMap} by the IDF of its key
	 * with respect to {@code neighborMaps}.
	 */
	public static Map<Integer, Double> getTFIDFMap(Map<Integer, Integer> srcMap, List<Map<Integer, Integer>> neighborMaps) {
		Map<Integer, Double> targetMap = new LinkedHashMap<Integer, Double>();
		for (Map.Entry<Integer, Integer> srcEntry : srcMap.entrySet()) {
			targetMap.put(srcEntry.getKey(), srcEntry.getValue() * getIDF(srcEntry.getKey(), neighborMaps));
		}
		return targetMap;
	}
	
	/**
	 * Inverse document frequency {@code ln(N / df)} of a tag, where {@code N} is the
	 * number of maps and {@code df} the number of maps containing the tag.
	 *
	 * @return the IDF, or 0 if no map contains the tag (including an empty list)
	 */
	public static double getIDF(int tagID, List<Map<Integer, Integer>> neighborMaps) {
		int documentFrequency = getTagFrequency(tagID, neighborMaps);
		if (documentFrequency == 0) {
			// Avoids the infinite (N / 0) and NaN (0 / 0) results of the division.
			return 0.0;
		}
		return Math.log((double)neighborMaps.size() / (double)documentFrequency);
	}
	
	/**
	 * Average of {@code Utilities.getMapCount} over all maps.
	 * The result is NaN for an empty list (0 / 0).
	 */
	public static double getLAverage(List<Map<Integer, Integer>> neighborMaps) {
		double sum = 0.0;
		for (Map<Integer, Integer> map : neighborMaps) {
			sum += Utilities.getMapCount(map);
		}
		return sum / (double)neighborMaps.size();
	}
	
	// Statics --------------------------------------------------------------------------------------------------------
	
	/** Accumulated time/memory measurement text, written to the metrics file by {@link #predictSample}. */
	private static String timeString;
	
	/**
	 * Trains a recommender on all but the last {@code sampleSize} bookmarks and
	 * computes the sorted tag list for each of the remaining ones. Training and
	 * test durations are appended to {@link #timeString}.
	 *
	 * @param ignoreResource if {@code true}, -1 is passed as resource ID for every test bookmark
	 * @return one result map per test bookmark, in bookmark order
	 */
	private static List<Map<Integer, Double>> startBM25CreationForTagPrediction(BookmarkReader reader, int sampleSize, boolean userBased, boolean resBased, int beta, boolean ignoreResource) {
		List<Bookmark> bookmarks = reader.getBookmarks();
		int size = bookmarks.size();
		int trainSize = size - sampleSize;
		Stopwatch timer = new Stopwatch();
		timer.start();
		CFTagRecommender calculator = new CFTagRecommender(reader, trainSize, userBased, resBased, beta);
		timer.stop();
		long trainingTime = timer.elapsed(TimeUnit.MILLISECONDS);
		
		List<Map<Integer, Double>> results = new ArrayList<Map<Integer, Double>>();
		timer.reset();
		timer.start();
		for (int i = trainSize; i < size; i++) {
			Bookmark data = bookmarks.get(i);
			int resID = (ignoreResource ? -1 : data.getResourceID());
			results.add(calculator.getRankedTagList(data.getUserID(), resID, true));
		}
		timer.stop();
		long testTime = timer.elapsed(TimeUnit.MILLISECONDS);
		
		timeString = PerformanceMeasurement.addTimeMeasurement(timeString, true, trainingTime, testTime, sampleSize);
		return results;
	}	
	
	/**
	 * Sets {@link #MAX_NEIGHBORS} to {@code neighbors} and delegates to {@link #predictSample}.
	 */
	public static BookmarkReader predictTags(String filename, int trainSize, int sampleSize, int neighbors, boolean userBased, boolean resBased, int beta, boolean ignoreResource) {
		MAX_NEIGHBORS = neighbors;
		return predictSample(filename, trainSize, sampleSize, userBased, resBased, beta, ignoreResource);
	}
	
	/**
	 * Reads the bookmark file, predicts tags for the last {@code sampleSize}
	 * bookmarks, writes the predictions to {@code filename + suffix + beta} and
	 * the time/memory metrics to {@code ./data/metrics/<outputFile>_TIME.txt}.
	 * <p>
	 * The suffix is {@code _rescf_} without the user-based part, {@code _usercf_}
	 * without the resource-based part and {@code _cf_} when both are enabled.
	 * <p>
	 * NOTE(review): the test lines are taken from index {@code trainSize} on, while
	 * the predictions cover the last {@code sampleSize} bookmarks - presumably callers
	 * pass {@code trainSize == size - sampleSize}; confirm.
	 *
	 * @return the reader holding the bookmarks and the test lines
	 */
	public static BookmarkReader predictSample(String filename, int trainSize, int sampleSize, boolean userBased, boolean resBased, int beta, boolean ignoreResource) {
		Timer timerThread = new Timer();
		MemoryThread memoryThread = new MemoryThread();
		timerThread.schedule(memoryThread, 0, MemoryThread.TIME_SPAN);
		
		BookmarkReader reader;
		String outputFile;
		try {
			reader = new BookmarkReader(trainSize, false);
			reader.readFile(filename);
			
			List<Map<Integer, Double>> cfValues = startBM25CreationForTagPrediction(reader, sampleSize, userBased, resBased, beta, ignoreResource);
			
			List<int[]> predictionValues = new ArrayList<int[]>();
			for (Map<Integer, Double> modelVal : cfValues) {
				predictionValues.add(Ints.toArray(modelVal.keySet()));
			}
			String suffix = "_cf_";
			if (!userBased) {
				suffix = "_rescf_";
			} else if (!resBased) {
				suffix = "_usercf_";
			}
			reader.setTestLines(reader.getBookmarks().subList(trainSize, reader.getBookmarks().size()));
			PredictionFileWriter writer = new PredictionFileWriter(reader, predictionValues);
			outputFile = filename + suffix + beta;
			writer.writeFile(outputFile);
			
			timeString = PerformanceMeasurement.addMemoryMeasurement(timeString, false, memoryThread.getMaxMemory());
		} finally {
			// Always stop the timer thread, also when reading or predicting fails.
			timerThread.cancel();
		}
		Utilities.writeStringToFile("./data/metrics/" + outputFile + "_TIME.txt", timeString);
		return reader;
	}
}