package de.jungblut.conll;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import info.debatty.java.stringsimilarity.QGram;
import org.apache.commons.cli.*;
import org.yaml.snakeyaml.Yaml;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.regex.Pattern;

import static com.google.common.base.Preconditions.checkArgument;

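/**
 * Vectorizes a CoNLL-2003 style NER dataset into fixed-length training sequences.
 * <p>
 * Every token is turned into a feature vector made of its GloVe embedding, a one-hot encoded
 * POS tag, a handful of binary word shape features and character q-gram counts. The token
 * stream is chunked into sequences of a fixed length (padded with zero vectors and the "O"
 * label at the tail), written out via a {@link SequenceFileWriter}, and the label, POS tag and
 * q-gram dictionaries are dumped alongside as a meta.yaml file so they can be reused for
 * vectorizing other files with the -l option.
 */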
public class VectorizerMain {

    private static final Pattern SPLIT_PATTERN = Pattern.compile(" ");
    // default number of tokens per output sequence
    private static final int SEQUENCE_LEN = 20;
    // length of the character n-grams (q-grams) used as features
    private static final int CHAR_NGRAM_LEN = 3;
    // q-grams must occur more often than this in the input to enter the dictionary
    private static final int MIN_CHAR_NGRAM_OCCURRENCE = 200;

    // default dimensionality of the GloVe embedding vectors
    private static final int EMBEDDING_VECTOR_SIZE = 50;
    // number of slots reserved for the one-hot encoded POS tags
    private static final int POS_TAG_SIZE = 47;
    // number of binary word shape features (see wordShape below)
    private static final int SHAPE_FEATURES_SIZE = 5;

    private static final String DATA_PATH = "data/";
    private static final String GLOVE_FILE_NAME = "glove.6B.50d.txt";
    private static final String NER_TRAIN_FILE_NAME = "eng.train.txt";
    private static final String TRAIN_OUT_FILE_NAME = "vectorized";
    private static final String META_OUT_FILE_NAME = "meta.yaml";
    private static final String OUT_LABEL = "O";

    private final LabelManager labelManager;
    private final LabelManager posTagManager;
    private final String[] qgramDict;
    private final int sequenceLength;
    private final int embeddingVectorSize;
    private final String embeddingPath;
    private final String inputFilePath;
    private final String outputFolder;
    private final boolean binaryOutput;

    private VectorizerMain(int sequenceLength,
                           int embeddingVectorSize,
                           String embeddingPath,
                           String inputFilePath,
                           String outputFolder,
                           boolean binaryOutput,
                           LabelManager labelManager,
                           LabelManager posTagManager,
                           String[] qgramDict) {
        this.sequenceLength = sequenceLength;
        this.embeddingVectorSize = embeddingVectorSize;
        this.embeddingPath = embeddingPath;
        this.inputFilePath = inputFilePath;
        this.outputFolder = outputFolder;
        this.binaryOutput = binaryOutput;
        this.labelManager = labelManager;
        this.posTagManager = posTagManager;
        this.qgramDict = qgramDict;
    }

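    /**
     * Runs the full vectorization pass: loads the embeddings, builds (or reuses) the q-gram
     * dictionary, converts every token of the input file into a feature vector, groups the
     * tokens into sequences of {@code sequenceLength} and writes them together with the
     * meta.yaml describing the output dimensions and label mappings.
     */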
    private void vectorize() throws IOException {
        System.out.println("Sequence length: " + sequenceLength);
        System.out.println("Embedding vector dimension: " + embeddingVectorSize);
        System.out.println("Embedding path: " + embeddingPath);
        System.out.println("Input path: " + inputFilePath);
        System.out.println("Binary output: " + binaryOutput);
        System.out.println("Output folder: " + outputFolder);

        // read the glove embeddings
        HashMap<String, float[]> embeddingMap = readGloveEmbeddings(embeddingPath, embeddingVectorSize);
        System.out.println("read " + embeddingMap.size()
                + " embedding vectors. Vectorizing...");

        labelManager.getOrCreate(OUT_LABEL);
        final QGram qgram = new QGram(CHAR_NGRAM_LEN);
        final String[] dict = qgramDict != null ? qgramDict : prepareNGramDictionary(qgram);
        System.out.println("qgram dictionary len: " + dict.length);

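        // per-token feature layout: [embedding | one-hot POS tag | shape features | q-gram counts]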
        final int singleFeatureSize = POS_TAG_SIZE + embeddingVectorSize + SHAPE_FEATURES_SIZE + dict.length;
        final int numFinalFeatures = sequenceLength * singleFeatureSize;

        System.out.println("word feature vector size: " + singleFeatureSize);

        // use an array of zeros as the default feature
        final float[] defaultVector = new float[singleFeatureSize];

        int numTotalTokens = 0;
        int numOutOfVocabTokens = 0;
        HashMultiset<String> labelHistogram = HashMultiset.create();

        try (SequenceFileWriter writer = createWriter(outputFolder, binaryOutput)) {

            Deque<float[]> vectorBuffer = new ArrayDeque<>();
            Deque<Integer> labelBuffer = new ArrayDeque<>();

            try (BufferedReader reader = new BufferedReader(new FileReader(
                    inputFilePath))) {

                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.isEmpty()) {
                        continue;
                    }

                    final float[] featureVector = new float[singleFeatureSize];
                    // CoNLL format: "German JJ I-NP I-MISC" -> token, POS tag, chunk tag, NER label
                    String[] split = SPLIT_PATTERN.split(line);
                    String tkn = cleanToken(split[0]);
                    numTotalTokens++;
                    if (embeddingMap.containsKey(tkn)) {
                        float[] embedding = embeddingMap.get(tkn);
                        System.arraycopy(embedding, 0, featureVector, 0, embedding.length);
                    } else {
                        numOutOfVocabTokens++;
                    }

                    // we add one hot encoded pos tags into the feature vector
                    String posTag = split[1].toLowerCase().trim();
                    int posTagIndex = posTagManager.getOrCreate(posTag);
                    featureVector[embeddingVectorSize + posTagIndex] = 1f;

                    // we add shape features from the non-normalized token
                    wordShape(split[0], embeddingVectorSize + POS_TAG_SIZE, featureVector);

                    // we add q-gram counts (a bag of character n-grams) into the feature vector
                    Map<String, Integer> profile = qgram.getProfile(tkn);
                    for (Map.Entry<String, Integer> entry : profile.entrySet()) {
                        int i = Arrays.binarySearch(dict, entry.getKey());
                        if (i >= 0) {
                            featureVector[embeddingVectorSize + POS_TAG_SIZE + SHAPE_FEATURES_SIZE + i] += entry.getValue();
                        }
                    }

                    String label = split[3].trim();
                    labelHistogram.add(label);
                    int labelIndex = labelManager.getOrCreate(label);

                    vectorBuffer.addLast(featureVector);
                    labelBuffer.addLast(labelIndex);

                    // once the buffer holds a full sequence, flush it to the writer
                    if (vectorBuffer.size() == sequenceLength) {
                        writeAndFillSequenceIfNeeded(defaultVector, writer, labelBuffer, vectorBuffer);
                    }
                }
            }

            while (!labelBuffer.isEmpty()) {
                writeAndFillSequenceIfNeeded(defaultVector, writer, labelBuffer, vectorBuffer);
            }
        }

        System.out.println("label histogram: " + labelHistogram);

        System.out.println("OOV tokens / total tokens: "
                + numOutOfVocabTokens
                + " / " + numTotalTokens
                + " = " + (numOutOfVocabTokens / (double) numTotalTokens) * 100d + "%");

        // dump the dimensions, label maps and q-gram dictionary as a YAML map
        Map<String, Object> data = new HashMap<>();
        data.put("embedding_dim", embeddingVectorSize);
        data.put("seq_len", sequenceLength);
        data.put("nlabels", labelManager.getLabelMap().size());
        data.put("feature_dim", singleFeatureSize);
        data.put("total_feature_dim", numFinalFeatures);
        // invert the maps so we can do int->string lookups elsewhere
        data.put("labels", labelManager.getLabelMap().inverse());
        data.put("pos_tags", posTagManager.getLabelMap().inverse());
        data.put("ngram_dict", dict);
        Yaml yaml = new Yaml();
        Files.write(Paths.get(outputFolder, META_OUT_FILE_NAME), yaml.dump(data)
                .getBytes());

        System.out.println("Done.");
    }

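    /**
     * Builds the character q-gram dictionary from the input file: counts all q-grams over the
     * cleaned tokens and keeps only those that occur more than {@code MIN_CHAR_NGRAM_OCCURRENCE}
     * times. The result is sorted so it can be binary-searched during vectorization.
     */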
    private String[] prepareNGramDictionary(QGram qgram) throws IOException {
        final HashMultiset<String> set = HashMultiset.create();
        try (BufferedReader reader = new BufferedReader(new FileReader(
                inputFilePath))) {

            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }

                String[] split = SPLIT_PATTERN.split(line);
                String tkn = cleanToken(split[0]);
                Map<String, Integer> profile = qgram.getProfile(tkn);
                for (Map.Entry<String, Integer> entry : profile.entrySet()) {
                    //noinspection ResultOfMethodCallIgnored
                    set.add(entry.getKey(), entry.getValue());
                }
            }
        }

        // apply a naive frequency cut-off: keep only q-grams seen more than MIN_CHAR_NGRAM_OCCURRENCE times
        return set.entrySet()
                .stream()
                .filter(e -> e.getCount() > MIN_CHAR_NGRAM_OCCURRENCE)
                .map(Multiset.Entry::getElement)
                .sorted()
                .toArray(String[]::new);
    }

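    // normalizes a token for the embedding and q-gram lookups: lower-cased and trimmed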
    private String cleanToken(String s) {
        return s.toLowerCase().trim();
    }

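    /**
     * Parses the CLI options, optionally restores the label, POS tag and q-gram dictionaries
     * from an existing meta.yaml (-l) and runs the vectorizer. A typical invocation looks like:
     *
     * <pre>
     * java de.jungblut.conll.VectorizerMain -i data/eng.train.txt -e data/glove.6B.50d.txt -o data/ -b
     * </pre>
     */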
    public static void main(String[] args) throws IOException, ParseException {

        Options options = new Options();
        options.addOption("s", "sequenceLength", true,
                "the length of the sequences the input is chunked into");
        options.addOption("d", "embvecdim", true,
                "the dimensionality of the embedding vectors");
        options.addOption("b", "binary", false,
                "if supplied, outputs in binary instead of text format");
        options.addOption("i", "input", true, "the path of the dataset");
        options.addOption("o", "output", true, "the folder for the output");
        options.addOption("e", "embeddings", true, "the path of the embeddings");
        options.addOption("l", "meta", true,
                "the path of the train meta yaml to get the labels");

        if (args.length > 0 && args[0].equals("-h")) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp("vectorizer", options);
            System.exit(0);
        }

        System.out.println("add -h for more options!");
        CommandLineParser parser = new DefaultParser();
        CommandLine cmd = parser.parse(options, args);

        int seqLen = Integer.parseInt(cmd.getOptionValue('s',
                String.valueOf(SEQUENCE_LEN)));
        int embeddingVectorSize = Integer.parseInt(cmd.getOptionValue('d',
                String.valueOf(EMBEDDING_VECTOR_SIZE)));

        String embeddingPath = cmd.getOptionValue('e', DATA_PATH + GLOVE_FILE_NAME);
        String inputFilePath = cmd.getOptionValue('i', DATA_PATH
                + NER_TRAIN_FILE_NAME);
        String outputFolderPath = cmd.getOptionValue('o', DATA_PATH);
        boolean binaryOutput = cmd.hasOption('b');

        LabelManager labelManager = new LabelManager();
        LabelManager posTagManager = new LabelManager();
        String[] qgramDict = null;
        if (cmd.hasOption('l')) {
            Yaml yaml = new Yaml();
            @SuppressWarnings("unchecked")
            Map<String, Object> map = (Map<String, Object>) yaml.load(new String(
                    Files.readAllBytes(Paths.get(cmd.getOptionValue('l')))));
            @SuppressWarnings("unchecked")
            Map<Integer, String> labels = (Map<Integer, String>) map.get("labels");
            labelManager = new LabelManager(labels);
            @SuppressWarnings("unchecked")
            Map<Integer, String> posLabels = (Map<Integer, String>) map.get("pos_tags");
            posTagManager = new LabelManager(posLabels);
            @SuppressWarnings("unchecked")
            List<String> dictList = (List<String>) map.get("ngram_dict");
            qgramDict = dictList.toArray(new String[0]);
        }

        VectorizerMain m = new VectorizerMain(seqLen,
                embeddingVectorSize, embeddingPath, inputFilePath, outputFolderPath,
                binaryOutput, labelManager, posTagManager, qgramDict);
        m.vectorize();
    }

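    /**
     * Writes five binary word shape features into {@code featureVector} starting at
     * {@code offset}: all-digit, all-uppercase, all-lowercase, "mixed" (an upper-case first
     * char followed only by lower-case chars, i.e. capitalized) and first-char-uppercase.
     * Assumes a non-empty token.
     */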
    private void wordShape(String s, int offset, float[] featureVector) {
        boolean digit = true;
        boolean upper = true;
        boolean lower = true;
        boolean mixed = true;
        boolean firstUpper = Character.isUpperCase(s.charAt(0));
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (!Character.isDigit(c)) {
                digit = false;
            }
            if (!Character.isLowerCase(c)) {
                lower = false;
            }
            if (!Character.isUpperCase(c)) {
                upper = false;
            }
            if ((i == 0 && !Character.isUpperCase(c)) || (i >= 1 && !Character.isLowerCase(c))) {
                mixed = false;
            }
        }

        if (digit) {
            featureVector[offset] = 1f;
        }
        if (upper) {
            featureVector[offset + 1] = 1f;
        }
        if (lower) {
            featureVector[offset + 2] = 1f;
        }
        if (mixed) {
            featureVector[offset + 3] = 1f;
        }
        if (firstUpper) {
            featureVector[offset + 4] = 1f;
        }
    }

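    /**
     * Drains the label and feature buffers into one written sequence, padding with the zero
     * default vector and the "O" label once the buffers run out (which only happens for the
     * tail of the input).
     */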
    private void writeAndFillSequenceIfNeeded(float[] defaultVector, SequenceFileWriter writer,
                                              Deque<Integer> labelBuffer, Deque<float[]> vectorBuffer) throws IOException {
        checkArgument(labelBuffer.size() == vectorBuffer.size(), "label and feature buffer sizes don't match");
        int[] labels = new int[Math.max(labelBuffer.size(), sequenceLength)];
        List<float[]> featList = new ArrayList<>();

        for (int i = 0; i < labels.length; i++) {
            if (labelBuffer.isEmpty()) {
                labels[i] = labelManager.getOrCreate("O");
                featList.add(defaultVector);
            } else {
                labels[i] = labelBuffer.pop();
                featList.add(vectorBuffer.pop());
            }
        }

        writer.write(labels, featList);
    }

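    /**
     * Reads GloVe embeddings from a text file where every line contains the token followed by
     * {@code embeddingVectorSize} space-separated float components, returning a map from token
     * to embedding vector.
     */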
    private static HashMap<String, float[]> readGloveEmbeddings(String inputFilePath, int embeddingVectorSize)
            throws IOException {
        HashMap<String, float[]> map = new HashMap<>();

        try (BufferedReader reader = new BufferedReader(new FileReader(inputFilePath))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] split = SPLIT_PATTERN.split(line);
                if (split.length != embeddingVectorSize + 1) {
                    throw new IllegalArgumentException(
                            "invalid embedding file, expected " + (embeddingVectorSize + 1)
                                    + " columns but encountered " + split.length);
                }
                float[] vector = new float[split.length - 1];
                for (int i = 0; i < split.length - 1; i++) {
                    vector[i] = Float.parseFloat(split[i + 1]);
                }
                map.put(split[0], vector);
            }
        }
        return map;
    }

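    /**
     * Creates the writer for the "vectorized" output file inside the output folder, binary or
     * text depending on the -b flag.
     */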
    private static SequenceFileWriter createWriter(String outputFolder,
                                                   boolean binary) throws IOException {
        // resolve the output file inside the folder so a missing trailing separator doesn't break the path
        String outputPath = Paths.get(outputFolder, TRAIN_OUT_FILE_NAME).toString();
        if (binary) {
            return new SequenceFileWriter.BinaryWriter(outputPath);
        } else {
            return new SequenceFileWriter.TextWriter(outputPath);
        }
    }
}