package org.aksw.mlqa.experimentold;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.nodes.Element;
import org.jsoup.select.Elements;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Lists;

import meka.classifiers.multilabel.RAkELd;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader.ArffReader;

public class CDTClassifierMultilable {
	private static Logger log = LoggerFactory.getLogger(CDTClassifierMultilable.class);
	
	
	public static void main(String[] args) throws Exception {		
		/*
		 * For multilable classification:
		 */
		
		//The classifier
		RAkELd PSt_Classifier = new RAkELd();
		//load the data
		Path datapath= Paths.get("./src/main/resources/old/Qald6Logs.arff");
		BufferedReader reader = new BufferedReader(new FileReader(datapath.toString()));
		ArffReader arff = new ArffReader(reader);
		Instances data = arff.getData();
		data.setClassIndex(6);
		PSt_Classifier.buildClassifier(data);		/*
		 * Test the trained system
		 */
		
		JSONObject qald6test = loadTestQuestions();
		JSONArray questions = (JSONArray) qald6test.get("questions");
		ArrayList<String> testQuestions = Lists.newArrayList();
		for(int i = 0; i < questions.size(); i++){
			JSONObject questionData = (JSONObject) questions.get(i);
			JSONArray questionStrings = (JSONArray) questionData.get("question");
			JSONObject questionEnglish = (JSONObject) questionStrings.get(0);
			testQuestions.add((String) questionEnglish.get("string"));
		}

		
		ArrayList<String> systems = Lists.newArrayList("KWGAnswer", "NbFramework", "PersianQA", "SemGraphQA", "UIQA_withoutManualEntries", "UTQA_English" );
		double ave_f = 0;
		Double ave_bestp = 0.0;
		Double ave_bestr = 0.0;

		for(int j = 0; j < data.size(); j++){
			Instance ins = data.get(j);
			double[] confidences = PSt_Classifier.distributionForInstance(ins);
			int argmax = -1;
			double max = -1;
				for(int i = 0; i < 6; i++){
					if(confidences[i]>max){
						max = confidences[i];
						argmax = i;
					}
				}
				//compare trained system with best possible system
				
				String sys2ask = systems.get(systems.size() - argmax -1);
				float p = Float.parseFloat(loadSystemP(sys2ask).get(j));				
				float r = Float.parseFloat(loadSystemR(sys2ask).get(j));
				
				double bestp = 0;
				double bestr = 0;
				String bestSystemp = "";
				String bestSystemr = ""; 
				for(String system:systems){
					if(Double.parseDouble(loadSystemP(system).get(j)) > bestp){bestSystemp = system; bestp = Double.parseDouble(loadSystemP(system).get(j));}; 
					if(Double.parseDouble(loadSystemR(system).get(j)) > bestr){bestSystemr = system; bestr = Double.parseDouble(loadSystemR(system).get(j));}; 
					}
				ave_bestp += bestp;
				ave_bestr += bestr;
				System.out.println(testQuestions.get(j));
				System.out.println(j + "... asked " + sys2ask + " with p " + loadSystemP(sys2ask).get(j) + "... best possible p: " + bestp + " was achieved by " + bestSystemp);
				System.out.println(j + "... asked " + sys2ask + " with r " + loadSystemR(sys2ask).get(j) + "... best possible r: " + bestr + " was achieved by " + bestSystemr);
				if(p>0&&r>0){ave_f += 2*p*r/(p + r);}

				}
		System.out.println("macro F : " + ave_f/data.size());
		}
	
	public static ArrayList<String> loadSystemP(String system){

		Path datapath = Paths.get("./src/main/resources/QALD6MultilingualLogs/multilingual_" + system + ".html");
		ArrayList<String> result = Lists.newArrayList();

		try{
			String loadedData = Files.lines(datapath).collect(Collectors.joining()); 
			Document doc = Jsoup.parse(loadedData);
			Element table = doc.select("table").get(5);
			Elements tableRows = table.select("tr");
			for(Element row: tableRows){
				Elements tableEntry = row.select("td");
				result.add(tableEntry.get(2).ownText());
			}
			result.remove(0); //remove the head of the table
			return result;
		}catch(IOException e){
			e.printStackTrace();
			log.debug("loading failed.");
			return result;
		}
	}
	public static ArrayList<String> loadSystemR(String system){
		Path datapath = Paths.get("./src/main/resources/QALD6MultilingualLogs/multilingual_" + system + ".html");
		ArrayList<String> result = Lists.newArrayList();

		try{
			String loadedData = Files.lines(datapath).collect(Collectors.joining()); 
			Document doc = Jsoup.parse(loadedData);
			Element table = doc.select("table").get(5);
			Elements tableRows = table.select("tr");
			for(Element row: tableRows){
				Elements tableEntry = row.select("td");
				result.add(tableEntry.get(1).ownText());
			}
			result.remove(0); //remove the head of the table
			return result;
		}catch(IOException e){
			e.printStackTrace();
			log.debug("loading failed.");
			return result;
		}
	}
	
	public static JSONObject loadTestQuestions(){
		String loadeddata;
		try {			
			Path datapath = Paths.get("./src/main/resources/qald-6-test-multilingual.json");
			loadeddata = Files.lines(datapath).collect(Collectors.joining());
			JSONParser parser = new JSONParser();
			JSONObject arr = (JSONObject) parser.parse(loadeddata);
			return arr;
		} catch (IOException | ParseException  e) {
			e.printStackTrace();
			log.debug("loading failed.");
			return new JSONObject();
		}
	}
	
	public static <T> Set<Set<T>> powerSet(Set<T> originalSet) {
	    Set<Set<T>> sets = new HashSet<Set<T>>();
	    if (originalSet.isEmpty()) {
	    	sets.add(new HashSet<T>());
	    	return sets;
	    }
	    List<T> list = new ArrayList<T>(originalSet);
	    T head = list.get(0);
	    Set<T> rest = new HashSet<T>(list.subList(1, list.size())); 
	    for (Set<T> set : powerSet(rest)) {
	    	Set<T> newSet = new HashSet<T>();
	    	newSet.add(head);
	    	newSet.addAll(set);
	    	sets.add(newSet);
	    	sets.add(set);
	    }		
	    return sets;
	}

}