/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.mitre.quaerite.cli;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.inference.TTest;
import org.apache.log4j.Logger;
import org.mitre.quaerite.connectors.QueryRequest;
import org.mitre.quaerite.connectors.SearchClient;
import org.mitre.quaerite.connectors.SearchClientException;
import org.mitre.quaerite.connectors.SearchClientFactory;
import org.mitre.quaerite.core.Experiment;
import org.mitre.quaerite.core.ExperimentConfig;
import org.mitre.quaerite.core.ExperimentSet;
import org.mitre.quaerite.core.JudgmentList;
import org.mitre.quaerite.core.Judgments;
import org.mitre.quaerite.core.QueryInfo;
import org.mitre.quaerite.core.QueryStrings;
import org.mitre.quaerite.core.SearchResultSet;
import org.mitre.quaerite.core.queries.Query;
import org.mitre.quaerite.core.queries.TermsQuery;
import org.mitre.quaerite.core.scorers.AbstractJudgmentScorer;
import org.mitre.quaerite.core.scorers.DistributionalScoreAggregator;
import org.mitre.quaerite.core.scorers.JudgmentScorer;
import org.mitre.quaerite.core.scorers.Scorer;
import org.mitre.quaerite.core.scorers.SearchResultSetScorer;
import org.mitre.quaerite.core.scorers.SummingScoreAggregator;
import org.mitre.quaerite.core.util.MapUtil;
import org.mitre.quaerite.db.ExperimentDB;
import org.mitre.quaerite.db.QueryRunnerDBClient;

public abstract class AbstractExperimentRunner extends AbstractCLI {
    static final Judgments POISON = new Judgments(new QueryInfo("",
            "", new QueryStrings(), -1));

    static Logger LOG = Logger.getLogger(AbstractExperimentRunner.class);

    static final int DEFAULT_NUM_THREADS = 8;
    private static final int MAX_MATRIX_COLS = 100;
    //this caches a judgment list of valid judgments
    //per search server url
    Map<String, JudgmentList> searchServerValidatedMap = new HashMap<>();

    private final ExperimentConfig experimentConfig;
    NumberFormat threePlaces = new DecimalFormat(".000",
            DecimalFormatSymbols.getInstance(Locale.US));

    public AbstractExperimentRunner(ExperimentConfig experimentConfig) {
        this.experimentConfig = experimentConfig;
    }


    void runExperiment(Experiment experiment, List<Scorer> scorers,
                       int maxRows, ExperimentDB experimentDB, JudgmentList judgmentList,
                       String judgmentListId, boolean logResults)
            throws SQLException, IOException, SearchClientException {
        if (experimentDB.hasScores(experiment.getName())) {
            LOG.info("Already has scores for " + experiment.getName() + "; skipping.  " +
                    "Use the -freshStart commandline option to clear all scores");
            return;
        }
        experimentDB.initScoreTable(scorers);
        SearchClient searchClient = SearchClientFactory.getClient(experiment.getSearchServerUrl());

        if (StringUtils.isBlank(experimentConfig.getIdField())) {
            LOG.info("default document 'idField' not set in experiment config. " +
                    "Will use default: '"
                    + searchClient.getDefaultIdField() + "'");
            experimentConfig.setIdField(searchClient.getDefaultIdField());
        }

        JudgmentList validated = searchServerValidatedMap.get(
                experiment.getSearchServerUrl() +
                        "_" + judgmentListId);
        if (validated == null) {
            validated = validate(searchClient, judgmentList);
            searchServerValidatedMap.put(experiment.getSearchServerUrl()
                    + "_" + judgmentListId, validated);
        }
        ExecutorService executorService = Executors.newFixedThreadPool(
                experimentConfig.getNumThreads());
        ExecutorCompletionService<Integer> executorCompletionService =
                new ExecutorCompletionService<>(executorService);
        ArrayBlockingQueue<Judgments> queue = new ArrayBlockingQueue<>(
                validated.getJudgmentsList().size() +
                        experimentConfig.getNumThreads());

        queue.addAll(validated.getJudgmentsList());
        for (int i = 0; i < experimentConfig.getNumThreads(); i++) {
            queue.add(POISON);
        }

        for (int i = 0; i < experimentConfig.getNumThreads(); i++) {
            executorCompletionService.submit(
                    new QueryRunner(experimentConfig.getIdField(), maxRows,
                            queue, experiment, experimentDB, scorers));
        }

        int completed = 0;
        while (completed < experimentConfig.getNumThreads()) {
            try {
                Future<Integer> future = executorCompletionService.take();
                future.get();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                completed++;
            }
        }
        executorService.shutdown();
        executorService.shutdownNow();
        //insertScores(experimentDB, experimentName, scoreAggregators);
        experimentDB.insertScoresAggregated(experiment.getName(), scorers);
        if (logResults) {
            logResults(experiment.getName(), scorers);
        }
    }

    private void logResults(String experimentName, List<Scorer> scorers) {
        StringBuilder result = new StringBuilder();
        LOG.info("Experiment: " + experimentName);
        for (Scorer scorer : scorers) {
            for (String querySetName : scorer.getQuerySets()) {
                Map<String, Double> summaryStats =
                        scorer.getSummaryStatistics(querySetName);
                if (!StringUtils.isBlank(querySetName)) {
                    result.append("Query Set: ").append(querySetName);
                } else {
                    result.append("All Queries: ");
                }
                result.append(scorer.getName());
                result.append(" - ");
                if (scorer instanceof SummingScoreAggregator) {
                    result.append("sum: ");
                    result.append(getValueString(summaryStats.get(SummingScoreAggregator.SUM)));
                } else if (scorer instanceof DistributionalScoreAggregator) {
                    result.append("mean: ");
                    result.append(
                            getValueString(summaryStats.get(DistributionalScoreAggregator.MEAN)));
                    result.append(", median: ");
                    result.append(
                            getValueString(summaryStats.get(DistributionalScoreAggregator.MEDIAN)));
                }
                LOG.info(result);
                result.setLength(0);
            }
        }
    }

    protected String getValueString(Double value) {

        if (value != null) {
            if ((long) value.doubleValue() == value) {
                return Long.toString((long) value.doubleValue());
            } else {
                return threePlaces.format(value);
            }
        } else {
            return "couldn't find value?!";
        }
    }

    /*
    private void insertScores(ExperimentDB experimentDB, String experimentName,
     List<ScoreAggregator> scoreAggregators)
            throws SQLException {
        Set<QueryInfo> queries = scoreAggregators.get(0).getScores().keySet();
        //TODO -- need to add better handling for missing queries
        Map<String, Double> tmpScores = new HashMap<>();
        for (QueryInfo queryInfo : queries) {
            tmpScores.clear();
            for (ScoreAggregator scoreAggregator : scoreAggregators) {
                double val = scoreAggregator.getScores().get(queryInfo);
                tmpScores.put(scoreAggregator.getName(), val);
            }
            experimentDB.insertScores(queryInfo, experimentName, scoreAggregators,
            tmpScores);
        }
    }

     */

    //TODO -- make this multi threaded

    /**
     * This reads through the judgment list and makes sure that the
     * a document with a given judgment's id is actually available in the
     * index.  This removes those ids that are not in the index and returns
     * a winnowed/validated {@link JudgmentList}.
     *
     * @param searchClient
     * @param judgmentList
     * @return
     */
    private JudgmentList validate(SearchClient searchClient,
                                  JudgmentList judgmentList)
            throws IOException, SearchClientException {
        String idField = searchClient.getIdField(experimentConfig);
        Set<String> judgmentIds = new HashSet<>();
        for (Judgments j : judgmentList.getJudgmentsList()) {
            judgmentIds.addAll(j.getSortedJudgments().keySet());
        }

        Set<String> valid = new HashSet<>();

        int len = 0;
        List<String> ids = new ArrayList<>();
        for (String id : judgmentIds) {
            ids.add(id);
            len += id.length();
            if (len > 1000) {
                addValid(new TermsQuery(idField, ids),
                        idField, searchClient, ids.size(), valid);
                len = 0;
                ids.clear();
            }
        }
        if (ids.size() > 0) {
            addValid(new TermsQuery(idField, ids),
                    idField, searchClient, ids.size(), valid);
        }

        int validIds = 0;
        int invalidIds = 0;
        if (judgmentIds.size() != valid.size()) {
            for (String id : judgmentIds) {
                if (!valid.contains(id)) {
                    invalidIds++;
                    LOG.warn("I regret that I could not find: " + id + " in the index. " +
                            "I'll remove this from the judgments before scoring.");
                } else {
                    validIds++;
                }
            }
        }
        if (invalidIds > 0) {
            LOG.warn("There were " + validIds + " unique valid ids and " +
                    invalidIds + " unique invalid ids");
        }
        int validQueries = 0;
        int invalidQueries = 0;
        JudgmentList retList = new JudgmentList();
        for (Judgments j : judgmentList.getJudgmentsList()) {
            //defensively copy
            Judgments winnowedJugments = new Judgments(
                    new QueryInfo(j.getQueryInfo().getQueryId(),
                            j.getQuerySet(), j.getQueryStrings(), j.getQueryCount()));
            for (Map.Entry<String, Double> e : j.getSortedJudgments().entrySet()) {
                if (valid.contains(e.getKey())) {
                    winnowedJugments.addJudgment(e.getKey(), e.getValue());
                } else {
                    LOG.warn("Could not find " + e.getKey() + " in the index!");
                }
            }
            if (winnowedJugments.getSortedJudgments().size() > 0) {
                retList.addJudgments(winnowedJugments);
                validQueries++;
            } else {
                LOG.warn(
                        "After removing invalid jugments, there were 0 " +
                                "judgments for query: " +
                                j.getQueryInfo().getQueryId());
                invalidQueries++;
            }
        }
        if (invalidQueries > 0) {
            LOG.warn("I had to remove " + invalidQueries +
                    " queries because there were no judgments for them. " +
                    " There were " + validQueries + " valid queries.");
        }
        return retList;

    }

    private static void addValid(TermsQuery termsQuery, String idField,
                                 SearchClient searchClient, int expected,
                                 Set<String> valid) {
        if (expected == 0) {
            return;
        }
        QueryRequest q = new QueryRequest(termsQuery, null, idField);
        q.addFieldsToRetrieve(idField);
        q.setNumResults(expected * 2);
        SearchResultSet searchResultSet;
        try {
            searchResultSet = searchClient.search(q);
        } catch (SearchClientException | IOException e) {
            throw new RuntimeException(e);
        }
        Set<String> localValid = new HashSet<>();
        for (int i = 0; i < searchResultSet.size(); i++) {
            String id = searchResultSet.get(i);
            if (localValid.contains(id)) {
                LOG.warn("Found non-unique key: " + id);
            }
            valid.add(id);
        }

    }


    static class QueryRunner implements Callable<Integer> {
        private static AtomicInteger IDs = new AtomicInteger();
        private final int threadNum = IDs.getAndIncrement();
        private final String idField;
        private final int maxRows;
        private final ArrayBlockingQueue<Judgments> queue;
        private final Experiment experiment;
        private final Query query;//thread safe clone of the query
        private final List<Scorer> scorers;
        private final SearchClient searchClient;
        private final QueryRunnerDBClient dbClient;
        private int batched = 0;

        public QueryRunner(String idField, int maxRows, ArrayBlockingQueue<Judgments> judgments,
                           Experiment experiment, ExperimentDB experimentDB,
                           List<Scorer> scorers) throws SQLException, IOException, SearchClientException {
            this.idField = idField;
            this.maxRows = maxRows;
            this.queue = judgments;
            this.experiment = experiment;
            this.query = experiment.getQuery();
            this.searchClient = SearchClientFactory.getClient(experiment.getSearchServerUrl());
            this.scorers = scorers;
            this.dbClient = experimentDB.getQueryRunnerDBClient(scorers);
        }

        @Override
        public Integer call() throws Exception {

            try {
                while (true) {
                    Judgments judgments = queue.poll();
                    if (judgments.equals(POISON)) {
//                    LOG.trace(threadNum + ": scorer thread hit poison. stopping now");
                        return 1;
                    }
                    scoreEach(judgments, scorers);
                    if (batched++ > 100) {
                        batched = 0;
                        dbClient.executeBatch();
                    }
                }
            } finally {
                Exception ex = null;
                try {
                    dbClient.close();
                } catch (Exception e) {
                    ex = e;
                }
                searchClient.close();
                if (ex != null) {
                    throw ex;
                }
            }
        }

        private void scoreEach(Judgments judgments,
                               List<Scorer> scorers) throws SQLException {
            query.setQueryStrings(judgments.getQueryStrings());

            QueryRequest queryRequest = new QueryRequest(query, experiment.getCustomHandler(), idField);
            queryRequest.addFieldsToRetrieve(idField);
            if (experiment.getFilterQueries().size() > 0) {
                queryRequest.addFilterQueries(experiment.getFilterQueries());
            }
            queryRequest.setNumResults(maxRows);

            SearchResultSet searchResultSet = null;
            try {
                searchResultSet = searchClient.search(queryRequest);
            } catch (SearchClientException | IOException e) {
                //TODO add exception to searchResultSet and log
                e.printStackTrace();
            }
            dbClient.insertSearchResults(judgments.getQueryInfo(),
                    experiment.getName(), searchResultSet);

            for (Scorer scorer : scorers) {
                if (scorer instanceof JudgmentScorer) {
                    ((JudgmentScorer) scorer).score(judgments, searchResultSet);
                } else if (scorer instanceof SearchResultSetScorer) {
                    ((SearchResultSetScorer) scorer).score(judgments.getQueryInfo(),
                            searchResultSet);
                } else {
                    throw new IllegalArgumentException("Scorer class not yet supported: "
                            + scorer.getClass());
                }
            }
            dbClient.insertScores(judgments.getQueryInfo(), experiment.getName(), scorers);
        }
    }


    ////////////DUMP RESULTS
    static void dumpResults(ExperimentSet experimentSet, ExperimentDB experimentDB,
                            List<String> querySets,
                            List<Scorer> scorers, Path outputDir, boolean isTest) throws Exception {
        if (!Files.isDirectory(outputDir)) {
            Files.createDirectories(outputDir);
        }
        dumpPerQuery(experimentDB, outputDir);

        String orderByPriority1 = null;
        String orderByPriority2 = null;
        for (Scorer scorer : experimentSet.getScorers()) {
            if (isTest && scorer instanceof AbstractJudgmentScorer &&
                    ((AbstractJudgmentScorer) scorer).getUseForTest()) {
                orderByPriority1 = scorer.getPrimaryStatisticName();
                break;
            }
            if (scorer instanceof AbstractJudgmentScorer &&
                    ((AbstractJudgmentScorer) scorer).getUseForTrain()) {
                orderByPriority2 = scorer.getPrimaryStatisticName();
            }
        }
        String orderBy = "";
        if (orderByPriority1 != null) {
            orderBy = " order by " + orderByPriority1 + " desc";
        } else if (orderByPriority1 == null && orderByPriority2 != null) {
            orderBy = " order by " + orderByPriority2 + " desc";
        }
        try (BufferedWriter writer = Files.newBufferedWriter(
                outputDir.resolve("scores_aggregated.csv"), StandardCharsets.UTF_8)) {
            try (Statement st = experimentDB.getConnection().createStatement()) {
                try (java.sql.ResultSet resultSet =
                             st.executeQuery("select * from SCORES_AGGREGATED "
                                     + orderBy)) {
                    writeHeaders(resultSet.getMetaData(), writer);
                    while (resultSet.next()) {
                        writeRow(resultSet, writer);
                    }
                }
                writer.flush();
            }
        }
        if (querySets.size() > 0) {
            for (String querySet : querySets) {
                dumpSignificanceMatrices(querySet, scorers, experimentDB, outputDir);
            }
        }
        //now dump across all query sets
        dumpSignificanceMatrices("", scorers, experimentDB, outputDir);


    }

    private static void dumpPerQuery(ExperimentDB experimentDB, Path outputDir) throws Exception {
        StringBuilder select = new StringBuilder();
        select.append("select " +
                "s.query_id QUERY_ID, " +
                "QUERY_NAME, " +
                "s.query_set QUERY_SET, " +
                "s.query_count QUERY_COUNT, " +
                "EXPERIMENT");
        for (String scorer : experimentDB.getScoreAggregatorNames()) {
            select.append(", ").append(scorer);
        }
        select.append(" from SCORES s");
        select.append(" join judgments j on s.query_id=j.query_id");
        if (experimentDB.hasNamedQuerySets()) {
            select.append(" where s.QUERY_SET <> ''");
        }
        select.append(" order by experiment, s.query_set, query_name");
        try (BufferedWriter writer = Files.newBufferedWriter(
                outputDir.resolve("per_query_scores.csv"), StandardCharsets.UTF_8)) {
            try (Statement st = experimentDB.getConnection().createStatement()) {

                try (java.sql.ResultSet resultSet = st.executeQuery(select.toString())) {
                    writeHeaders(resultSet.getMetaData(), writer);
                    while (resultSet.next()) {
                        writeRow(resultSet, writer);
                    }
                }
                writer.flush();
            }
        }
    }

    private static void dumpSignificanceMatrices(String querySet,
                                                 List<Scorer> targetScorers,
                                                 ExperimentDB experimentDB,
                                                 Path outputDir) throws Exception {
        TTest tTest = new TTest();
        for (Scorer scorer : targetScorers) {
            if (scorer instanceof AbstractJudgmentScorer &&
                    ((AbstractJudgmentScorer) scorer).getExportPMatrix()) {
                Map<String, Double> aggregatedScores =
                        experimentDB.getKeyExperimentScore(scorer, querySet);

                Map<String, Double> sorted = MapUtil.sortByDescendingValue(aggregatedScores);
                List<String> experiments = new ArrayList();
                experiments.addAll(sorted.keySet());
                writeMatrix(tTest, (AbstractJudgmentScorer) scorer,
                        querySet, experiments, experimentDB, outputDir);
            }
        }
    }

    private static void writeMatrix(TTest tTest, AbstractJudgmentScorer scorer,
                                    String querySet,
                                    List<String> experiments,
                                    ExperimentDB experimentDB,
                                    Path outputDir) throws Exception {

        String fileName = "sig_diffs_" + scorer.getName() + (
                (StringUtils.isBlank(querySet)) ? ".csv" : "_" + querySet + ".csv");

        List<String> matrixExperiments = new ArrayList<>();
        for (int i = 0; i < experiments.size() && i < MAX_MATRIX_COLS; i++) {
            matrixExperiments.add(experiments.get(i));
        }
        try (BufferedWriter writer = Files.newBufferedWriter(outputDir.resolve(fileName))) {

            for (String experiment : matrixExperiments) {
                writer.write(",");
                writer.write(experiment);
            }
            writer.write("\n");

            for (int i = 0; i < matrixExperiments.size(); i++) {
                String experimentA = matrixExperiments.get(i);
                writer.write(experimentA);
                for (int k = 0; k <= i; k++) {
                    writer.write(",");
                }
                writer.write(String.format(Locale.US, "%.3G", 1.0d) + ",");//p-value of itself
                //map of query -> score for experiment A given this particular scorer
                Map<String, Double> scoresA = experimentDB.getScores(querySet,
                        experimentA, scorer.getName());
                for (int j = i + 1; j < matrixExperiments.size(); j++) {
                    String experimentB = matrixExperiments.get(j);
                    double significance =
                            calcSignificance(tTest, querySet, scoresA,
                                    experimentA, experimentB,
                            scorer.getName(), experimentDB);
                    writer.write(String.format(Locale.US, "%.3G", significance));
                    writer.write(",");
                }
                writer.write("\n");
            }
        }
    }

    private static double calcSignificance(TTest tTest, String querySet,
                                           Map<String, Double> scoresA, String experimentA,
                                           String experimentB, String scorer,
                                           ExperimentDB experimentDB) throws SQLException {

        Map<String, Double> scoresB = experimentDB.getScores(querySet, experimentB, scorer);
        if (scoresA.size() != scoresB.size()) {
            //log
            System.err.println("Different number of scores for " +
                    experimentA + "(" + scoresA.size() +
                    ") vs. " + experimentB + "(" + scoresB.size() + ")");
        }
        double[] arrA = new double[scoresA.size()];
        double[] arrB = new double[scoresB.size()];

        int i = 0;
        for (String query : scoresA.keySet()) {
            Double scoreA = scoresA.get(query);
            Double scoreB = scoresB.get(query);
            if (scoreA == null || scoreA < 0) {
                scoreA = 0.0d;
            }
            if (scoreB == null || scoreB < 0) {
                scoreB = 0.0d;
            }
            arrA[i] = scoreA;
            arrB[i] = scoreB;
            i++;
        }
//        WilcoxonSignedRankTest w = new WilcoxonSignedRankTest();
        //      w.wilcoxonSignedRankTest()
        if (arrA.length < 2) {
            LOG.warn("too few examples for t-test; returning -1");
            return -1;
        }
        return tTest.tTest(arrA, arrB);

    }

    private static void writeHeaders(ResultSetMetaData metaData, BufferedWriter writer)
            throws Exception {
        for (int i = 1; i <= metaData.getColumnCount(); i++) {
            writer.write(clean(metaData.getColumnName(i)));
            writer.write(",");
        }
        writer.write("\n");
    }

    private static void writeRow(java.sql.ResultSet resultSet, BufferedWriter writer)
            throws Exception {
        for (int i = 1; i <= resultSet.getMetaData().getColumnCount(); i++) {
            writer.write(clean(resultSet.getString(i)));
            writer.write(",");
        }
        writer.write("\n");
    }

    private static String clean(String string) {
        if (string == null) {
            return "";
        }
        string = string.replaceAll("[\r\n]", " ");
        if (string.contains(",")) {
            string.replaceAll("\"", "\"\"");
            string = "\"" + string + "\"";
        }
        return string;
    }
}