package doser.word2vec.semanticCategories;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.log4j.Logger;
import org.jgrapht.Graph;
import org.jgrapht.UndirectedGraph;
import org.jgrapht.alg.DijkstraShortestPath;
import org.jgrapht.graph.AbstractBaseGraph;
import org.jgrapht.graph.ClassBasedEdgeFactory;
import org.jgrapht.graph.DefaultEdge;

import com.hp.hpl.jena.query.QueryException;
import com.hp.hpl.jena.query.QueryExecution;
import com.hp.hpl.jena.query.QueryExecutionFactory;
import com.hp.hpl.jena.query.QueryFactory;
import com.hp.hpl.jena.query.QuerySolution;
import com.hp.hpl.jena.query.ResultSet;
import com.hp.hpl.jena.rdf.model.Model;
import com.hp.hpl.jena.rdf.model.ModelFactory;
import com.hp.hpl.jena.rdf.model.Property;
import com.hp.hpl.jena.rdf.model.RDFNode;
import com.hp.hpl.jena.rdf.model.Resource;
import com.hp.hpl.jena.rdf.model.Statement;
import com.hp.hpl.jena.rdf.model.StmtIterator;

/**
 * Samples pairs of entities whose categories lie a given number of hops apart
 * in a category graph.
 * <p>
 * Eligible categories are read from {@link #CATEGORYPURITY}. The category
 * graph is built from SKOS {@code broader} statements and stored as an
 * undirected graph. Entities of a category are looked up with a
 * {@code dcterms:subject} SPARQL query against an in-memory Jena model.
 */
public class Sampling {

	private static final Logger LOG = Logger.getLogger(Sampling.class);

	public static final String CATEGORYPURITY = "/home/zwicklbauer/word2vec/MSEDbpediaCategories_Min5.txt";

	public static final int MAXIMUMSAMPLENR = 5000;

	/** N-Triples file whose SKOS {@code broader} statements form the category graph. */
	private static final String SKOS_CATEGORIES = "/home/zwicklbauer/HDTGeneration/skos_categories_en.nt";

	/** N-Triples file queried for {@code dcterms:subject} (entity -> category) statements. */
	private static final String ARTICLE_CATEGORIES = "/home/zwicklbauer/HDTGeneration/article_categories_en.nt";

	private static final String SKOS_BROADER = "http://www.w3.org/2004/02/skos/core#broader";

	/** Exclusive upper bound on the first column of {@link #CATEGORYPURITY} for a category to be kept. */
	private static final double MAX_SCORE = 0.033;

	/** Exclusive lower bound on the first column of {@link #CATEGORYPURITY} for a category to be kept. */
	private static final double MIN_SCORE = -2;

	private final Graph<String, DefaultEdge> graph;

	/** Eligible categories in file order (duplicates in the file are kept). */
	private final String[] catSet;
	/** Eligible categories, for O(1) membership tests during the random walk. */
	private final HashSet<String> catHash;

	/** Model holding the article-category statements that SPARQL queries run against. */
	private final Model m;

	private final Random random;

	/**
	 * Loads the eligible categories, builds the category graph and reads the
	 * article-category model.
	 * <p>
	 * A line of {@link #CATEGORYPURITY} is {@code score<TAB>categoryUri}. A
	 * category is kept when {@code MIN_SCORE < score < MAX_SCORE}. Blank and
	 * single-column lines are skipped. An I/O error while reading the file is
	 * logged, and the categories read so far are kept.
	 */
	public Sampling() {
		List<String> catList = new ArrayList<String>();
		this.catHash = new HashSet<String>();
		BufferedReader reader = null;
		try {
			// Decode explicitly as UTF-8 instead of the platform default charset,
			// so non-ASCII category IRIs are read consistently on every machine.
			reader = new BufferedReader(new InputStreamReader(
					new FileInputStream(CATEGORYPURITY), "UTF-8"));
			String line;
			while ((line = reader.readLine()) != null) {
				String[] fields = line.split("\t");
				if (fields.length < 2) {
					// Skip blank/malformed lines instead of aborting the whole load.
					continue;
				}
				double score = Double.parseDouble(fields[0]);
				if (score < MAX_SCORE && score > MIN_SCORE) {
					catList.add(fields[1]);
					this.catHash.add(fields[1]);
				}
			}
		} catch (IOException e) {
			LOG.error("Could not read category file " + CATEGORYPURITY, e);
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException e) {
					LOG.error("Could not close category file " + CATEGORYPURITY, e);
				}
			}
		}
		this.catSet = catList.toArray(new String[catList.size()]);
		this.graph = createGraph();
		this.m = ModelFactory.createDefaultModel();
		this.m.read(ARTICLE_CATEGORIES);
		this.random = new Random();
		System.out.println(this.catHash.size());
	}

	/**
	 * Samples entity pairs for each category distance {@code 0 .. maxDistance-1}.
	 * <p>
	 * One sampling attempt (see {@link #samplePair(int)}) does the following:
	 * <ol>
	 * <li>Picks a random eligible start category and a random entity from it.</li>
	 * <li>Random-walks the category graph (at least one step) until it reaches
	 * an eligible category, and picks a random entity from that one.</li>
	 * <li>Keeps the pair only if the entities differ (case-insensitively) and the
	 * shortest path between the two categories has exactly {@code d} edges.</li>
	 * </ol>
	 * Sampling for distance {@code d} stops once {@link #MAXIMUMSAMPLENR} pairs
	 * have been newly added to its set. The running count is printed after each
	 * addition.
	 * <p>
	 * NOTE(review): there is no attempt limit. If a distance cannot be reached
	 * from any eligible category, this method never returns.
	 *
	 * @param maxDistance exclusive upper bound on the sampled path lengths
	 * @return map from path length (edge count) to the sampled pairs
	 * @throws IllegalStateException if {@code maxDistance > 0} but no eligible
	 *             categories were loaded
	 */
	public Map<Integer, HashSet<EntityPair>> generateCandidates(int maxDistance) {
		if (maxDistance > 0 && catSet.length == 0) {
			throw new IllegalStateException("No categories loaded from "
					+ CATEGORYPURITY);
		}
		ConcurrentMap<Integer, HashSet<EntityPair>> map = new ConcurrentHashMap<Integer, HashSet<EntityPair>>();
		for (int stepsize = 0; stepsize < maxDistance; stepsize++) {
			int counter = 0;
			while (counter < MAXIMUMSAMPLENR) {
				EntityPair pair = samplePair(stepsize);
				if (pair == null) {
					continue;
				}
				HashSet<EntityPair> set = map.get(stepsize);
				if (set == null) {
					set = new HashSet<EntityPair>();
					map.put(stepsize, set);
				}
				// Count only pairs that were actually added, so duplicates do not
				// use up the per-distance quota.
				if (set.add(pair)) {
					counter++;
					System.out.println(counter);
				}
			}
		}
		return map;
	}

	/**
	 * Performs a single sampling attempt for the given distance.
	 *
	 * @param stepsize required number of edges on the shortest path between
	 *            the two categories
	 * @return the sampled pair, or {@code null} if the attempt was rejected
	 */
	private EntityPair samplePair(int stepsize) {
		String basicCat = pickCategory();
		String e1 = queryEntitiesFromCategory(basicCat);
		if (e1 == null) {
			return null;
		}
		String targetCat = walkToRelevantCategory(basicCat);
		if (targetCat == null) {
			return null;
		}
		String e2 = queryEntitiesFromCategory(targetCat);
		// Test the cheap rejection before running the shortest-path search.
		if (e2 == null || e1.equalsIgnoreCase(e2)) {
			return null;
		}
		List<DefaultEdge> path = DijkstraShortestPath.findPathBetween(graph,
				basicCat, targetCat);
		if (path == null || path.size() != stepsize) {
			return null;
		}
		return new EntityPair(e1, e2, basicCat, targetCat);
	}

	/**
	 * Random-walks from {@code start}, taking at least one step, until an
	 * eligible category is reached.
	 *
	 * @return the eligible category reached, or {@code null} if a step was
	 *         impossible (see {@link #performRandomStep(String)})
	 */
	private String walkToRelevantCategory(String start) {
		String current = start;
		do {
			current = performRandomStep(current);
			if (current == null) {
				return null;
			}
		} while (!catHash.contains(current));
		return current;
	}

	/**
	 * Moves from {@code current} to the other endpoint of one of its incident
	 * edges, chosen uniformly among those edges.
	 *
	 * @return the neighbouring category, or {@code null} if {@code current} is
	 *         {@code null}, is not a vertex of the graph, or has no edges
	 */
	private String performRandomStep(String current) {
		if (current == null || !graph.containsVertex(current)) {
			return null;
		}
		Set<DefaultEdge> edges = graph.edgesOf(current);
		if (edges.isEmpty()) {
			return null;
		}
		int chosen = random.nextInt(edges.size());
		int index = 0;
		for (DefaultEdge e : edges) {
			if (index == chosen) {
				String source = graph.getEdgeSource(e);
				// URIs are case-sensitive, so compare exactly when deciding
				// which endpoint is the current vertex.
				return source.equals(current) ? graph.getEdgeTarget(e) : source;
			}
			index++;
		}
		return null; // not reached: chosen < edges.size()
	}

	/** Returns a uniformly random entry of {@link #catSet}. */
	private String pickCategory() {
		return catSet[random.nextInt(catSet.length)];
	}

	/**
	 * Builds the category graph from the SKOS {@code broader} statements in
	 * {@link #SKOS_CATEGORIES}.
	 * <p>
	 * Each statement {@code narrower skos:broader broader} whose subject and
	 * object both have URIs becomes an undirected edge. Statements with a
	 * literal or blank-node endpoint are skipped. Progress is printed every
	 * 10000 edges. The temporary Jena model is closed before returning.
	 *
	 * @return the undirected category graph
	 */
	public UndirectedGraph<String, DefaultEdge> createGraph() {
		UndirectedGraph<String, DefaultEdge> categoryGraph = new MiGrafo();
		Model model = ModelFactory.createDefaultModel();
		try {
			model.read(SKOS_CATEGORIES);
			StmtIterator it = model.listStatements();
			try {
				int counter = 0;
				while (it.hasNext()) {
					Statement s = it.next();
					Property predicate = s.getPredicate();
					RDFNode object = s.getObject();
					if (!predicate.getURI().equalsIgnoreCase(SKOS_BROADER)
							|| !object.isResource()) {
						continue;
					}
					Resource subject = s.getSubject();
					String narrower = subject.getURI();
					String broader = object.asResource().getURI();
					if (narrower == null || broader == null) {
						// Blank nodes have no URI; the graph does not accept null vertices.
						continue;
					}
					// addVertex does nothing (returns false) if the vertex already exists.
					categoryGraph.addVertex(narrower);
					categoryGraph.addVertex(broader);
					// MiGrafo allows parallel edges, so a repeated statement adds another
					// edge and makes the random walk more likely to take it.
					categoryGraph.addEdge(narrower, broader);
					if (counter % 10000 == 0) {
						System.out.println(counter);
					}
					counter++;
				}
			} finally {
				it.close();
			}
		} finally {
			model.close();
		}
		return categoryGraph;
	}

	/**
	 * Returns a random entity whose {@code dcterms:subject} is {@code catUri}.
	 *
	 * @param catUri category URI, inserted into the SPARQL query as an IRI
	 * @return a uniformly chosen matching entity URI, or {@code null} if there
	 *         is none or the query failed (the failure is logged)
	 */
	private String queryEntitiesFromCategory(final String catUri) {
		// NOTE(review): catUri is concatenated into the query. A URI containing
		// '>' or whitespace makes the query fail to parse; that failure is caught
		// and logged below.
		final String query = "SELECT ?entities WHERE{ ?entities <http://purl.org/dc/terms/subject> <"
				+ catUri + ">. }";
		QueryExecution qexec = null;
		try {
			final com.hp.hpl.jena.query.Query cquery = QueryFactory
					.create(query);
			qexec = QueryExecutionFactory.create(cquery, m);
			final ResultSet results = qexec.execSelect();
			// ArrayList: the result is indexed randomly below.
			final List<String> entities = new ArrayList<String>();
			while (results.hasNext()) {
				final QuerySolution sol = results.nextSolution();
				entities.add(sol.getResource("entities").getURI());
			}
			if (!entities.isEmpty()) {
				return entities.get(this.random.nextInt(entities.size()));
			}
		} catch (final QueryException e) {
			LOG.error("SPARQL query failed for category " + catUri, e);
		} finally {
			if (qexec != null) {
				qexec.close();
			}
		}
		return null;
	}

	/**
	 * Undirected graph with {@code String} vertices and {@link DefaultEdge}
	 * edges. The two {@code true} flags passed to {@link AbstractBaseGraph}
	 * allow multiple edges and self-loops.
	 * <p>
	 * Static so that the serializable graph does not capture the enclosing
	 * {@code Sampling} instance.
	 */
	static class MiGrafo extends AbstractBaseGraph<String, DefaultEdge> implements
			UndirectedGraph<String, DefaultEdge> {

		private static final long serialVersionUID = 1L;

		MiGrafo() {
			super(new ClassBasedEdgeFactory<String, DefaultEdge>(
					DefaultEdge.class), true, true);
		}
	}

	/**
	 * Samples pairs for distances 0..3 and writes them, UTF-8 encoded, one per
	 * line, as {@code distance<TAB>entity1<TAB>entity2<TAB>category1<TAB>category2}.
	 */
	public static void main(String[] args) {
		Sampling sampling = new Sampling();
		Map<Integer, HashSet<EntityPair>> map = sampling.generateCandidates(4);
		File file = new File("/home/zwicklbauer/samplingoutput.dat");
		PrintWriter writer = null;
		try {
			writer = new PrintWriter(file, "UTF-8");
			for (Map.Entry<Integer, HashSet<EntityPair>> entry : map.entrySet()) {
				Integer key = entry.getKey();
				for (EntityPair p : entry.getValue()) {
					writer.println(String.valueOf(key) + "\t" + p.getEntity1()
							+ "\t" + p.getEntity2() + "\t" + p.getCategory1()
							+ "\t" + p.getCategory2());
				}
			}
			// PrintWriter swallows write errors; surface them here.
			if (writer.checkError()) {
				LOG.error("Error while writing " + file);
			}
		} catch (IOException e) {
			LOG.error("Could not open output file " + file, e);
		} finally {
			if (writer != null) {
				writer.close();
			}
		}
	}

}