package edu.stanford.nlp.mt.wordcls;

import java.io.IOException;
import java.io.LineNumberReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import edu.stanford.nlp.mt.util.IOTools;
import edu.stanford.nlp.mt.util.IString;
import edu.stanford.nlp.mt.util.IStrings;
import edu.stanford.nlp.mt.util.Sequence;
import edu.stanford.nlp.mt.util.TokenUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;

/**
 * Various algorithms for learning a mapping function from an input
 * word to an output equivalence class.
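 * 
 * <p>A typical command-line invocation (a sketch; run with no arguments
 * to print the full option list):
 * <pre>{@code
 * java edu.stanford.nlp.mt.wordcls.MakeWordClasses -nclasses 512 corpus.txt > classes.tsv
 * }</pre>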
 * 
 * TODO Extract out objective function as an interface to support
 * other clustering algorithms if needed.
 * 
 * @author Spence Green
 *
 */
public class MakeWordClasses {

  private static final Logger logger = LogManager.getLogger(MakeWordClasses.class);
  
  private final int numIterations;
  private final int numClasses;
  private final int numThreads;
  private final int vparts;
  private final int order;
  private final String inputEncoding;

  // Output formats: tab-separated "word<TAB>classId" (TSV) or the
  // SRILM class-definition format "classId 1.0 word" (SRILM).
  private enum OutputFormat { SRILM, TSV }

  private static final int INITIAL_CAPACITY = 100000;
  private final Map<IString,Integer> wordToClass;
  private final Counter<IString> wordCount;
  private final TwoDimensionalCounter<IString, NgramHistory> historyCount;
  private final TwoDimensionalCounter<Integer,NgramHistory> classHistoryCount;
  private final ClassicCounter<Integer> classCount;
  private final OutputFormat outputFormat;
  private final int vocabThreshold;
  private List<IString> effectiveVocabulary;
  private final boolean normalizeDigits;
  private final boolean writeUnkClass;

  private double currentObjectiveValue = 0.0;
  
  public MakeWordClasses(Properties properties) {
    // User options
    this.numIterations = PropertiesUtils.getInt(properties, "niters", 30);
    assert this.numIterations > 0;

    this.numClasses = PropertiesUtils.getInt(properties, "nclasses", 512);
    assert this.numClasses > 0;

    this.numThreads = PropertiesUtils.getInt(properties, "nthreads", 1);
    assert this.numThreads > 0;

    this.vparts = PropertiesUtils.getInt(properties, "vparts", 3);
    assert this.vparts > 0;

    this.order = PropertiesUtils.getInt(properties, "order", 2);
    assert this.order > 1;

    this.vocabThreshold = PropertiesUtils.getInt(properties, "vclip", 5);
    assert this.vocabThreshold >= 0;
    
    this.inputEncoding = properties.getProperty("encoding", IOTools.DEFAULT_ENCODING);

    this.normalizeDigits = PropertiesUtils.getBool(properties, "normdigits", true);
    
    this.writeUnkClass = PropertiesUtils.getBool(properties, "writeunk", false);

    this.outputFormat = OutputFormat.valueOf(
        properties.getProperty("format", OutputFormat.TSV.toString()).toUpperCase());

    logger.info("#iterations: {}", numIterations);
    logger.info("#classes: {}", numClasses);
    logger.info("order: {}", order);
    logger.info("#vocabulary partitions: {}", vparts);
    logger.info("Rare word threshold: {}", vocabThreshold);
    logger.info("Input file encoding: {}", inputEncoding);
    if (normalizeDigits) {
      logger.info("Mapping all ASCII digit characters to 0");
    }

    // Internal data structures
    wordToClass = new HashMap<>(INITIAL_CAPACITY);
    wordCount = new ClassicCounter<>(INITIAL_CAPACITY);
    classCount = new ClassicCounter<>(numClasses);
    historyCount = new TwoDimensionalCounter<>();
    classHistoryCount = new TwoDimensionalCounter<>();
  }

  /**
   * Read the input and create the initial clustering.
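   * 
   * <p>For order 2, each token is counted with a one-word history: in
   * the sentence "the cat", "the" is counted under the start-of-sentence
   * history and "cat" under the history [the].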
   * 
   * @param filenames
   * @throws IOException
   */
  private void initialize(String[] filenames) throws IOException {
    List<IString> defaultHistory = new ArrayList<>();
    for (int i = 0; i < order-1; ++i) {
      defaultHistory.add(TokenUtils.START_TOKEN);
    }
    
    // Read the vocabulary and histories
    final long startTime = System.nanoTime();
    for (String filename : filenames) {
      logger.info("Reading: {}", filename);
      try (LineNumberReader reader = IOTools.getReaderFromFile(filename, inputEncoding)) {
        for (String line; (line = reader.readLine()) != null;) {
          line = line.trim();
          if (line.isEmpty()) continue;
          Sequence<IString> tokens = IStrings.tokenize(line);
          List<IString> history = new ArrayList<>(defaultHistory);
          for (IString token : tokens) {
            if (normalizeDigits && TokenUtils.hasDigit(token.toString())) {
              token = new IString(TokenUtils.normalizeDigits(token.toString()));
            }
            wordCount.incrementCount(token);
            historyCount.incrementCount(token, new NgramHistory(history));

            // Slide the (order-1)-gram history window forward by one token
            history.add(token);
            history.remove(0);
          }
        }
      }
    }
    NgramHistory.lockIndex();
    final double elapsedTime = ((double) System.nanoTime() - startTime) / 1e9;
    logger.info(String.format("Done reading input files (%.3fsec)", elapsedTime));
    logger.info(String.format("Input gross statistics: %d words  %d tokens  %d histories", 
        wordCount.keySet().size(), (int) wordCount.totalCount(), (int) historyCount.totalCount()));

    // Collapse vocabulary by mapping rare words to <unk>
    Set<IString> fullVocabulary = new HashSet<>(wordCount.keySet());
    Set<IString> filteredWords = new HashSet<>(fullVocabulary.size());
    for (IString word : fullVocabulary) {
      int count = (int) wordCount.getCount(word);
      if (vocabThreshold > 0 && count < vocabThreshold) {
        filteredWords.add(word);
        wordCount.incrementCount(TokenUtils.UNK_TOKEN, count);
        wordCount.remove(word);
        Counter<NgramHistory> histories = historyCount.getCounter(word);
        Counter<NgramHistory> unkHistories = historyCount.getCounter(TokenUtils.UNK_TOKEN);
        Counters.addInPlace(unkHistories, histories);
        historyCount.remove(word);
        // Rare words are all assigned the extra class id <numClasses>
        if (writeUnkClass) System.out.printf("%s\t%d%n", word.toString(), numClasses);
      }
    }

    // Setup the vocabulary that will be clustered (i.e., the
    // effective vocabulary)
    if (filteredWords.size() > 0) {
      logger.info(String.format("Mapping %d / %d words to unk token %s", 
          filteredWords.size(), fullVocabulary.size(), TokenUtils.UNK_TOKEN.toString()));
      fullVocabulary.add(TokenUtils.UNK_TOKEN);
    }
    fullVocabulary.removeAll(filteredWords);
    effectiveVocabulary = new ArrayList<>(fullVocabulary);

    // Initialize clustering: sort words by frequency (most frequent first)
    // and deal them into the classes round-robin
    Collections.sort(effectiveVocabulary, Counters.toComparator(wordCount, false, true));
    for (int i = 0; i < effectiveVocabulary.size(); ++i) {
      IString word = effectiveVocabulary.get(i);
      int classId = i % numClasses;
      classCount.incrementCount(classId, wordCount.getCount(word));
      wordToClass.put(word, classId);
      Counter<NgramHistory> historiesForWord = historyCount.getCounter(word);
      Counter<NgramHistory> historiesForClass = classHistoryCount.getCounter(classId);
      Counters.addInPlace(historiesForClass, historiesForWord);
    }
    Collections.shuffle(effectiveVocabulary);

    // Debug output
    logger.info("Effective vocabulary size: " + String.valueOf(effectiveVocabulary.size()));
    currentObjectiveValue = objectiveFunctionValue();
    logger.info("Finished generating initial cluster assignment");
    logger.info(String.format("Initial objective function value: %.3f%n", currentObjectiveValue));
  }

  /**
   * Create word clusters from the list of input files.
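   * 
   * <p>Each iteration updates one vocabulary partition (partition
   * e % vparts), so every vparts consecutive iterations make one full
   * pass over the effective vocabulary.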
   * 
   * @param filenames
   */
  public void run(String[] filenames) {
    final long runStartTime = System.nanoTime();
    try {
      initialize(filenames);
    } catch (IOException e1) {
      throw new RuntimeException(e1);
    }

    logger.info(String.format("Starting clustering with %d threads", numThreads));
    for (int e = 0; e < numIterations; ++e) {
      MulticoreWrapper<ClustererState,PartialStateUpdate> threadpool = 
          new MulticoreWrapper<ClustererState,PartialStateUpdate>(numThreads, 
              new ThreadsafeProcessor<ClustererState,PartialStateUpdate>() {
            @Override
            public PartialStateUpdate process(ClustererState input) {
              OneSidedObjectiveFunction algorithm = new OneSidedObjectiveFunction(input);
              return algorithm.cluster();
            }
            @Override
            public ThreadsafeProcessor<ClustererState, PartialStateUpdate> newInstance() {
              return this;
            }
          });
      
      // Select vocabulary partition number
      final int partitionNumber = e % vparts;

      if (e > 0 && partitionNumber == 0) {
        logger.info("Sorting vocabulary according to the current class assignments");
        Collections.sort(effectiveVocabulary, new WordClassComparator(wordToClass));
      }

      logger.info(String.format("Iteration %d: partition %d start", e, partitionNumber));
      final long iterationStartTime = System.nanoTime();
      int startIndex = 0;
      for (int t = 0; t < numThreads; ++t) {
        Pair<ClustererState,Integer> input = createInput(partitionNumber, t, startIndex);
        if (input != null) {
          threadpool.put(input.first());
          startIndex = input.second();
        }
      }

      // Wait for shutdown and process results
      threadpool.join();
      int numUpdates = 0;
      while (threadpool.peek()) {
        PartialStateUpdate result = threadpool.poll();
        numUpdates += updateCountsWith(result);
      }

      // Clean out zeros from counters after updating
      classHistoryCount.clean();
      Counters.retainNonZeros(classCount);

      double elapsedTime = ((double) System.nanoTime() - iterationStartTime) / 1e9;
      logger.info(String.format("Iteration %d: elapsed time %.3fsec", e, elapsedTime));
      logger.info(String.format("Iteration %d: #updates %d", e, numUpdates));
      logger.info(String.format("Iteration %d: objective: %.4f", e, objectiveFunctionValue()));
    }

    double elapsedTime = ((double) System.nanoTime() - runStartTime) / 1e9;
    logger.info(String.format("Total runtime: %.3fsec", elapsedTime));
  }

  private static class WordClassComparator implements Comparator<IString> {
    private final Map<IString, Integer> map;
    public WordClassComparator(Map<IString, Integer> map) {
      this.map = map;
    }
    @Override
    public int compare(IString a, IString b) {
      return Integer.compare(map.get(a), map.get(b));
    }
  }
  
  /**
   * Objective function of Uszkoreit and Brants (2008) (Eq. 10).
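   * 
   * <p>In the notation of this class the value is
   * sum_c [ sum_h N(h,c) * log N(h,c) - N(c) * log N(c) ],
   * where N(h,c) is the count of history h preceding class c and N(c)
   * is the total count of class c.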
   * 
   * @return the objective function value under the current clustering
   */
  private double objectiveFunctionValue() {    
    double objValue = 0.0;
    for (int classId = 0; classId < numClasses; ++classId) {
      Counter<NgramHistory> historyCount = classHistoryCount.getCounter(classId);
      for (NgramHistory history : historyCount.keySet()) {
        double count = historyCount.getCount(history);
        assert count > 0.0;
        objValue += count * Math.log(count);
      }
      double count = classCount.getCount(classId);
      if (count > 0.0) {
        objValue -= count * Math.log(count);
      } else {
        logger.warn("Empty cluster: {}", classId);
      }
    }
    return objValue;
  }

  /**
   * Create the input to a clustering iteration.
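   * 
   * <p>For example, with an effective vocabulary of 10 words and
   * vparts = 3, the partitions cover indices [0,3), [3,6), and [6,10);
   * the final partition absorbs the remainder.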
   * 
   * @param partitionNumber the vocabulary partition to cluster
   * @param threadId the worker thread id (used for logging)
   * @param inputStart the vocabulary index at which this worker's input begins
   * @return the worker's input state paired with the end index of its
   *         input, or null if no input remains in this partition
   */
  private Pair<ClustererState,Integer> createInput(int partitionNumber, int threadId, int inputStart) {
    int partitionSize = effectiveVocabulary.size() / vparts;
    int partitionStart = partitionNumber*partitionSize;
    int partitionEnd = partitionNumber == vparts-1 ? effectiveVocabulary.size() : (partitionNumber+1)*partitionSize;
    partitionSize = partitionEnd-partitionStart;

    int targetInputSize = partitionSize / numThreads;
    int startIndex = inputStart == 0 ? partitionStart : inputStart;
    int endIndex = Math.min(partitionEnd, startIndex + targetInputSize);
    if (endIndex - startIndex <= 0) return null;
    
    // Uszkoreit and Brants (2008) heuristic: make sure that all words from
    // a given class end up in the same worker by extending the input to the
    // next class boundary.
    int i = endIndex-1;
    for (; i < partitionEnd-1; ++i) {
      IString iWord = effectiveVocabulary.get(i);
      IString nextWord = effectiveVocabulary.get(i+1);
      int iClass = wordToClass.get(iWord);
      int nextClass = wordToClass.get(nextWord);
      if (iClass != nextClass) {
        break;
      }
    }
    logger.info(String.format("endIndex: %d -> %d", endIndex, i+1));
    endIndex = i+1;

    List<IString> inputVocab = effectiveVocabulary.subList(startIndex, endIndex);
    
    logger.info(String.format("Partition %d thread %d size %d: input %d-%d", partitionNumber,
        threadId, inputVocab.size(), startIndex, endIndex-1));
    
    // Create the state
    ClustererState state = new ClustererState(inputVocab, this.wordCount, 
        this.historyCount, this.wordToClass, this.classCount, this.classHistoryCount,
        numClasses, this.currentObjectiveValue);
    return new Pair<ClustererState,Integer>(state, endIndex);
  }

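  /**
   * Merge a worker's partial counts and class assignments into the
   * global state.
   * 
   * @param result the partial update produced by one worker
   * @return the number of words whose class assignment changed
   */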
  private int updateCountsWith(PartialStateUpdate result) {
    // Update counts
    Counters.addInPlace(classCount, result.deltaClassCount);
    Set<Integer> classes = result.deltaClassHistoryCount.firstKeySet();
    for (Integer classId : classes) {
      Counter<NgramHistory> counter = this.classHistoryCount.getCounter(classId);
      Counter<NgramHistory> delta = result.deltaClassHistoryCount.getCounter(classId);
      Counters.addInPlace(counter, delta);
    }

    // Update assignments
    int numUpdates = 0;
    for (Map.Entry<IString, Integer> assignment : result.wordToClass.entrySet()) {
      int oldAssignment = wordToClass.get(assignment.getKey());
      int newAssignment = assignment.getValue();
      if (oldAssignment != newAssignment) {
        ++numUpdates;
        wordToClass.put(assignment.getKey(), assignment.getValue());
      }
    }
    return numUpdates;
  }
  
  /**
   * Write the final cluster assignments to the specified output stream.
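   * 
   * <p>For example, an assignment of the word "house" to class 17 is
   * written as "house\t17" in TSV format and as "17 1.0 house" in
   * SRILM format.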
   * 
   * @param out
   */
  public void writeResults(PrintStream out) {
    logger.info(String.format("Writing final class assignments in %s format",
        outputFormat.toString()));
    Collections.sort(effectiveVocabulary, new WordClassComparator(wordToClass));
    for (IString word : effectiveVocabulary) {
      int assignment = wordToClass.get(word);
      if (outputFormat == OutputFormat.TSV) {
        out.printf("%s\t%d%n", word.toString(), assignment);

      } else if (outputFormat == OutputFormat.SRILM) {
        out.printf("%d 1.0 %s%n", assignment, word.toString());
      }
    }
  }

  private static Map<String, Integer> optionArgDefs() {
    Map<String,Integer> argDefs = new HashMap<>();
    argDefs.put("order", 1);
    argDefs.put("nthreads", 1);
    argDefs.put("nclasses", 1);
    argDefs.put("niters", 1);
    argDefs.put("vparts", 1);
    argDefs.put("format", 1);
    argDefs.put("name", 1);
    argDefs.put("vclip", 1);
    argDefs.put("normdigits", 1);
    argDefs.put("encoding", 1);
    return argDefs;
  }

  private static String usage() {
    StringBuilder sb = new StringBuilder();
    final String nl = System.getProperty("line.separator");
    sb.append("Usage: java ").append(MakeWordClasses.class.getName()).append(" OPTS file [file] > output").append(nl)
    .append(" -order num       : Model order (default: 2)").append(nl)
    .append(" -nthreads num    : Number of threads (default: 1)").append(nl)
    .append(" -nclasses num    : Number of classes (default: 512)").append(nl)
    .append(" -niters num      : Number of iterations (default: 30)").append(nl)
    .append(" -vparts num      : Number of vocabulary partitions (default: 3)").append(nl)
    .append(" -format type     : Output format [srilm|tsv] (default: tsv)").append(nl)
    .append(" -name str        : Run name for log file.").append(nl)
    .append(" -vclip num       : Map rare words to <unk> (default: 5)").append(nl)
    .append(" -normdigits bool : Map ASCII digits to 0 (default: true)").append(nl)
    .append(" -encoding str    : Input file encoding (default: UTF-8)");

    return sb.toString();
  }

  /**
   * Run the clusterer from the command line.
   * 
   * @param args options followed by one or more input files; see usage()
   */
  public static void main(String[] args) {
    Properties options = StringUtils.argsToProperties(args, optionArgDefs());
    String[] filenames = options.getProperty("","").split("\\s+");
    if (filenames.length < 1 || filenames[0].length() == 0 || options.containsKey("h")
        || options.containsKey("help")) {
      System.err.println(usage());
      System.exit(-1);
    }
    MakeWordClasses mkWordCls = new MakeWordClasses(options);
    mkWordCls.run(filenames);
    mkWordCls.writeResults(System.out);
  }
}