package processing.musicrec;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Timer;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Stopwatch;
import com.google.common.primitives.Ints;

import common.DoubleMapComparator;
import common.MapUtil;
import common.Bookmark;
import common.MemoryThread;
import common.PerformanceMeasurement;
import common.Utilities;
import file.PredictionFileWriter;
import processing.BLLCalculator;
import file.BookmarkReader;

/**
 * User-based collaborative-filtering recommender that can optionally be
 * blended with per-user BLL scores (see {@link BLLCalculator}).
 * <p>
 * The bookmarks supplied by the {@link BookmarkReader} are split into a
 * training part (the first {@code trainSize} entries) and a test part (the
 * rest). Each user is described by a profile map (id -&gt; weight); neighbors
 * are ranked by cosine similarity of these profiles and their weighted
 * profiles are summed up to score candidates.
 * <p>
 * User ids are used directly as list indices into the per-user structures.
 * Instances cache pairwise similarities in a plain map and are not designed
 * for concurrent use.
 */
public class MusicCFRecommender {
	
	/**
	 * Maximum number of most-similar neighbors whose profiles contribute to a
	 * recommendation. Overwritten by {@link #predictTags}.
	 */
	public static int MAX_NEIGHBORS = 20;
	
	private BookmarkReader reader;
	/** First {@code trainSize} bookmarks of the reader. */
	private List<Bookmark> trainList;
	/** Remaining bookmarks of the reader (everything after the training part). */
	private List<Bookmark> testList;
	/** Per-user profiles (id -> weight), indexed by user id. */
	private List<Map<Integer, Double>> userMaps;
	/** Cache of pairwise similarities, keyed by "smallerUserId;largerUserId". */
	private Map<String, Double> simMap;
	/** Per-user BLL scores, indexed by user id; {@code null} when no BLL value was given. */
	private List<Map<Integer, Double>> bllMap;
	/** Weight of the CF part in the hybrid; {@code null} disables the hybrid, negative selects a per-bookmark weight. */
	private Double beta;
	
	/**
	 * Builds the recommender and the user profiles it works on.
	 *
	 * @param reader    source of all bookmarks
	 * @param trainSize number of leading bookmarks that form the training set
	 * @param bllVal    parameter handed to {@link BLLCalculator#getArtifactMaps}; {@code null} skips the BLL maps
	 * @param beta      hybrid weight, see {@link #getRankedTagList}; may be {@code null}
	 * @param type      profile type: {@code "general"} uses {@code Utilities.getFloatUserMaps},
	 *                  {@code "pop"} the 30 top-valued entries per user, anything else
	 *                  the 30 most recently seen tags per user; must not be {@code null}
	 */
	public MusicCFRecommender(BookmarkReader reader, int trainSize, Double bllVal, Double beta, String type) {
		this.reader = reader;
		this.beta = beta;
		this.trainList = this.reader.getBookmarks().subList(0, trainSize);
		this.testList = this.reader.getBookmarks().subList(trainSize, this.reader.getBookmarks().size());
		this.simMap = new LinkedHashMap<String, Double>();
		
		if (bllVal != null) {
			this.bllMap = BLLCalculator.getArtifactMaps(this.reader, this.trainList, this.testList, false, 
					new ArrayList<Long>(), new ArrayList<Double>(), bllVal.doubleValue(), true, null, true);
		}
		if (type.equals("general")) {
			this.userMaps = Utilities.getFloatUserMaps(this.trainList);
			System.out.println("Avg. number of genres per user: " + this.getAvgUserSize());
		} else if (type.equals("pop")) {
			this.userMaps = getTopUserArtists(30);
		} else { // time
			this.userMaps = getRecentUserArtists(30);
		}
	}
	
	/**
	 * Builds one profile per user containing at most {@code limit} entries,
	 * taken in the order delivered by {@code MapUtil.sortByValue}
	 * (presumably highest value first - TODO confirm against MapUtil).
	 *
	 * @param limit maximum number of entries per user profile
	 * @return profiles in the order of {@code Utilities.getUserMaps(trainList)}
	 */
	private List<Map<Integer, Double>> getTopUserArtists(int limit) {
		List<Map<Integer, Double>> returnList = new ArrayList<Map<Integer, Double>>();
		for (Map<Integer, Integer> uMap : Utilities.getUserMaps(this.trainList)) {
			Map<Integer, Double> userArtists = new LinkedHashMap<Integer, Double>();
			for (Map.Entry<Integer, Integer> artist : MapUtil.sortByValue(uMap).entrySet()) {
				if (userArtists.size() >= limit) {
					break;
				}
				userArtists.put(artist.getKey(), (double) artist.getValue());
			}
			returnList.add(userArtists);
		}
		return returnList;
	}
	
	/**
	 * Builds one profile per user from the tags of that user's bookmarks,
	 * walking the bookmark list from its last element to its first. The first
	 * distinct tag found gets weight {@code limit}, the next {@code limit - 1},
	 * and so on, until {@code limit} distinct tags were collected or the
	 * bookmarks are exhausted.
	 *
	 * @param limit maximum number of distinct tags per user profile
	 * @return profiles in the order of {@code Utilities.getBookmarks(trainList, false)}
	 */
	private List<Map<Integer, Double>> getRecentUserArtists(int limit) {
		List<Map<Integer, Double>> returnList = new ArrayList<Map<Integer, Double>>();
		List<List<Bookmark>> userBookmarks = Utilities.getBookmarks(this.trainList, false);
		for (List<Bookmark> bookmarks : userBookmarks) {
			Map<Integer, Double> uMap = new LinkedHashMap<Integer, Double>();
			int count = 0;
			int index = bookmarks.size() - 1;
			while (count < limit && index >= 0) {
				Bookmark b = bookmarks.get(index--);
				for (int t : b.getTags()) {
					if (!uMap.containsKey(t)) {
						uMap.put(t, (double) (limit - count++));
						if (count >= limit) {
							break;
						}
					}
				}
			}
			returnList.add(uMap);
		}
		return returnList;
	}
	
	/**
	 * @return mean of all similarities cached so far, or 0.0 if none were computed
	 */
	private double getAvgPairwiseSim() {
		if (this.simMap.isEmpty()) {
			return 0.0;
		}
		double sum = 0.0;
		for (double s : this.simMap.values()) {
			sum += s;
		}
		return sum / this.simMap.size();
	}
	
	/**
	 * @return mean number of entries per user profile, or 0.0 if there are no profiles
	 */
	private double getAvgUserSize() {
		if (this.userMaps.isEmpty()) {
			return 0.0;
		}
		double sum = 0.0;
		for (Map<Integer, Double> map : this.userMaps) {
			sum += map.size();
		}
		return sum / this.userMaps.size();
	}
	
	/**
	 * Writes every cached similarity as a {@code "idA;idB;value"} line to
	 * {@code ./data/metrics/<filename>_cosine_sim.txt}. I/O errors are printed
	 * and otherwise ignored (best effort).
	 *
	 * @param filename base name used to build the output path
	 */
	private void printPairwiseSim(String filename) {
		System.out.println("Entries to write: " + this.simMap.size());
		// Use the "_jaccard_sim.txt" suffix instead when getNeighbors is switched to Jaccard similarity.
		String fileToWrite = "./data/metrics/" + filename + "_cosine_sim.txt";
		BufferedWriter bw = null;
		try {
			bw = new BufferedWriter(new FileWriter(new File(fileToWrite)));
			for (Map.Entry<String, Double> entry : this.simMap.entrySet()) {
				bw.write(entry.getKey() + ";");
				bw.write(entry.getValue() + "\n");
			}
			bw.flush();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			// Close in finally so the file handle is released even when a write fails.
			if (bw != null) {
				try {
					bw.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}
	
	/**
	 * Resolves the hybrid weight for one bookmark. Must only be called when
	 * {@code beta} is not {@code null}.
	 *
	 * @param data bookmark being predicted
	 * @return the configured {@code beta} if it is non-negative; otherwise the
	 *         bookmark's rating if that is positive, else 0.5
	 */
	private double getBetaForUser(Bookmark data) {
		if (this.beta < 0) {
			if (data.getRating() > 0) {
				return data.getRating();
			} else {
				System.out.println("Wrong rating");
				return 0.5;
			}
		} else {
			return this.beta.doubleValue();
		}
	}
	
	/**
	 * @return {@code true} if {@code userID} is a valid index into the user profiles
	 */
	private boolean isKnownUser(int userID) {
		return userID >= 0 && userID < this.userMaps.size();
	}
	
	/**
	 * Scores candidates for the user of the given bookmark.
	 * <p>
	 * The profiles of the first {@link #MAX_NEIGHBORS} neighbors are summed,
	 * each entry weighted by the neighbor's similarity. If {@code beta} and the
	 * BLL maps are available for this user, the CF scores are passed through
	 * {@code MapUtil.normalizeMap(scores, beta)} and the user's BLL scores,
	 * multiplied by {@code 1 - beta}, are added on top.
	 *
	 * @param data    bookmark whose user the recommendation is for
	 * @param sorting if {@code true}, the scores are ordered via
	 *                {@code MapUtil.sortByValue}, cut to {@code Utilities.REC_LIMIT}
	 *                entries and - when {@code Utilities.FILTER_OWN} is set - entries
	 *                already in the user's own profile are skipped; if {@code false},
	 *                the raw score map is returned
	 * @return id -&gt; score map, never {@code null}
	 */
	public Map<Integer, Double> getRankedTagList(Bookmark data, boolean sorting) {
		int userID = data.getUserID();
		Map<Integer, Double> resultMap = new LinkedHashMap<Integer, Double>();
		
		// Entries of the user's own profile are filtered from the sorted output.
		// An unknown user simply has nothing to filter (getNeighbors reports the bad id).
		Map<Integer, Double> targetUserMap;
		if (Utilities.FILTER_OWN && isKnownUser(userID)) {
			targetUserMap = this.userMaps.get(userID);
		} else {
			targetUserMap = new LinkedHashMap<Integer, Double>();
		}
		
		int usedNeighbors = 0;
		for (Map.Entry<Integer, Double> neighbor : getNeighbors(userID).entrySet()) {
			if (usedNeighbors++ >= MAX_NEIGHBORS) {
				break;
			}
			// NOTE(review): assumes every neighbor id (taken from the test set) is a valid
			// index into userMaps - confirm against Utilities.getFloatUserMaps/getUserMaps.
			Map<Integer, Double> neighborMap = this.userMaps.get(neighbor.getKey());
			double simVal = neighbor.getValue().doubleValue();
			
			for (Map.Entry<Integer, Double> artist : neighborMap.entrySet()) {
				double artistSimVal = simVal * artist.getValue().doubleValue();
				Double val = resultMap.get(artist.getKey());
				resultMap.put(artist.getKey(), (val != null ? val.doubleValue() + artistSimVal : artistSimVal));
			}
		}
		
		// hybrid! bllMap only exists when a BLL value was passed to the constructor,
		// so it has to be checked separately from beta.
		if (this.beta != null && this.bllMap != null && userID >= 0 && userID < this.bllMap.size()) {
			double userBeta = getBetaForUser(data);
			MapUtil.normalizeMap(resultMap, userBeta);
			Map<Integer, Double> userBllMap = this.bllMap.get(userID);
			for (Map.Entry<Integer, Double> artist : userBllMap.entrySet()) {
				double bllval = artist.getValue().doubleValue() * (1.0 - userBeta);
				Double cfval = resultMap.get(artist.getKey());
				double hybridval = (cfval != null ? cfval.doubleValue() + bllval : bllval);
				resultMap.put(artist.getKey(), hybridval);
			}
		}
		
		if (!sorting) {
			return resultMap;
		}
		Map<Integer, Double> returnMap = new LinkedHashMap<Integer, Double>();
		for (Map.Entry<Integer, Double> entry : MapUtil.sortByValue(resultMap).entrySet()) {
			if (returnMap.size() >= Utilities.REC_LIMIT) {
				break;
			}
			if (!targetUserMap.containsKey(entry.getKey())) {
				returnMap.put(entry.getKey(), entry.getValue().doubleValue());
			}
		}
		return returnMap;
	}
	
	/**
	 * Collects the candidate neighbors of a user - every other user id that
	 * occurs in the test set - together with their cosine similarity to the
	 * user's profile. Similarities are computed once per user pair and cached.
	 *
	 * @param userID id of the target user
	 * @return neighbor id -&gt; similarity, ordered by {@code MapUtil.sortByValue};
	 *         if {@code userID} has no profile, all similarities are 0.0 and the
	 *         map is in test-set encounter order
	 */
	private Map<Integer, Double> getNeighbors(int userID) {
		Map<Integer, Double> neighbors = new LinkedHashMap<Integer, Double>();
		// candidate neighbors: all other users occurring in the test set
		for (Bookmark data : this.testList) {
			if (data.getUserID() != userID) {
				neighbors.put(data.getUserID(), 0.0);
			}
		}
		
		if (!isKnownUser(userID)) {
			System.out.println("Wrong user id");
			return neighbors;
		}
		
		Map<Integer, Double> targetMap = this.userMaps.get(userID);
		for (Map.Entry<Integer, Double> entry : neighbors.entrySet()) {
			int neighborID = entry.getKey();
			// Order-independent cache key: smaller id first.
			String simID = (userID < neighborID ? userID + ";" + neighborID : neighborID + ";" + userID);
			Double simVal = this.simMap.get(simID);
			if (simVal == null) {
				// Alternative measure: Utilities.getJaccardFloatSim(targetMap, this.userMaps.get(neighborID))
				simVal = Utilities.getCosineFloatSim(targetMap, this.userMaps.get(neighborID));
				this.simMap.put(simID, simVal);
			}
			entry.setValue(simVal);
		}
		return MapUtil.sortByValue(neighbors);
	}
	
	// Statics --------------------------------------------------------------------------------------------------------
	
	/**
	 * Runs the recommender for the last {@code sampleSize} bookmarks of the
	 * reader, prints the average cached similarity and dumps all cached
	 * similarities to disk.
	 *
	 * @return one sorted recommendation map per test bookmark, in bookmark order
	 */
	private static List<Map<Integer, Double>> startCollaborativeFiltering(BookmarkReader reader, int sampleSize, String filename, Double bllVal, Double beta, String type) {
		int size = reader.getBookmarks().size();
		int trainSize = size - sampleSize;

		MusicCFRecommender calculator = new MusicCFRecommender(reader, trainSize, bllVal, beta, type);
		List<Map<Integer, Double>> results = new ArrayList<Map<Integer, Double>>();
		for (int i = trainSize; i < size; i++) {
			Bookmark data = reader.getBookmarks().get(i);
			results.add(calculator.getRankedTagList(data, true));
		}
		System.out.println("Average pairwise sim: " + calculator.getAvgPairwiseSim());
		calculator.printPairwiseSim(filename);
		
		return results;
	}
	
	/**
	 * Sets {@link #MAX_NEIGHBORS} to {@code neighbors} and delegates to
	 * {@link #predictSample}.
	 *
	 * @param neighbors number of neighbors to use; stays in effect for later calls
	 * @return the reader produced by {@link #predictSample}
	 */
	public static BookmarkReader predictTags(String filename, int trainSize, int sampleSize, int neighbors, Double bllVal, Double beta, String type) {
		MAX_NEIGHBORS = neighbors;
		return predictSample(filename, trainSize, sampleSize, bllVal, beta, type);
	}
	
	/**
	 * Reads {@code filename}, computes recommendations for the last
	 * {@code sampleSize} bookmarks and writes them with a
	 * {@link PredictionFileWriter}.
	 * <p>
	 * The output name is {@code filename + suffix + MAX_NEIGHBORS}, where the
	 * suffix is {@code "_cf_<type>_"} without beta, {@code "_bll_cf_"} with a
	 * beta, and {@code "_bll_cf_static_"} with a positive beta.
	 *
	 * @param filename   input file handed to the reader; also the base of the output name
	 * @param trainSize  passed to the {@link BookmarkReader} constructor and used as the
	 *                   start index of the reader's test lines
	 * @param sampleSize number of trailing bookmarks to predict
	 * @param bllVal     see constructor
	 * @param beta       see constructor
	 * @param type       see constructor
	 * @return the reader with its test lines set
	 */
	public static BookmarkReader predictSample(String filename, int trainSize, int sampleSize, Double bllVal, Double beta, String type) {
		BookmarkReader reader = new BookmarkReader(trainSize, false);
		reader.readFile(filename);
		
		List<Map<Integer, Double>> cfValues = startCollaborativeFiltering(reader, sampleSize, filename, bllVal, beta, type);
		
		// Only the ranked ids are written; the scores themselves are dropped here.
		List<int[]> predictionValues = new ArrayList<int[]>();
		for (Map<Integer, Double> modelVal : cfValues) {
			predictionValues.add(Ints.toArray(modelVal.keySet()));
		}
		
		String suffix = "_cf_" + type + "_";
		if (beta != null) {
			suffix = "_bll_cf_";
			if (beta > 0) {
				suffix += "static_";
			}
		}
		reader.setTestLines(reader.getBookmarks().subList(trainSize, reader.getBookmarks().size()));
		PredictionFileWriter writer = new PredictionFileWriter(reader, predictionValues);
		String outputFile = filename + suffix + MAX_NEIGHBORS;
		writer.writeFile(outputFile);

		return reader;
	}
}