package info.semanticanalyzer.classifiers.weka.fiveway;

import com.google.inject.internal.util.Join;
import info.semanticanalyzer.classifiers.weka.SentimentClass;
import weka.classifiers.bayes.NaiveBayesMultinomialText;
import weka.classifiers.evaluation.Evaluation;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

import java.util.ArrayList;

/**
 * Trains, evaluates, persists and applies a five-way sentiment classifier built on Weka's
 * {@link NaiveBayesMultinomialText}.
 *
 * <p>The in-memory dataset has two attributes: a string attribute {@code "content"} holding the
 * space-joined text, and a nominal class attribute {@code "@@class@@"} whose labels are the names
 * of the {@link SentimentClass.FiveWayClazz} constants.
 *
 * <p>Created by dmitrykan on 27.04.2014.
 */
public class FiveWayMNBTrainer {
    /** Index of the string attribute holding the text to classify. */
    private static final int CONTENT_INDEX = 0;
    /** Index of the nominal class (sentiment label) attribute. */
    private static final int CLASS_INDEX = 1;

    private NaiveBayesMultinomialText classifier;
    private final String modelFile;
    private final Instances dataRaw;

    /**
     * Creates a trainer with an untrained classifier and an empty training dataset.
     *
     * @param outputModel path that {@link #saveModel()} writes the classifier to
     */
    public FiveWayMNBTrainer(String outputModel) {
        classifier = new NaiveBayesMultinomialText();
        modelFile = outputModel;

        ArrayList<String> classVal = new ArrayList<String>(5);
        classVal.add(SentimentClass.FiveWayClazz.NEGATIVE.name());
        classVal.add(SentimentClass.FiveWayClazz.SOMEWHAT_NEGATIVE.name());
        classVal.add(SentimentClass.FiveWayClazz.NEUTRAL.name());
        classVal.add(SentimentClass.FiveWayClazz.SOMEWHAT_POSITIVE.name());
        classVal.add(SentimentClass.FiveWayClazz.POSITIVE.name());

        ArrayList<Attribute> atts = new ArrayList<Attribute>(2);
        atts.add(new Attribute("content", (ArrayList<String>) null));
        atts.add(new Attribute("@@class@@", classVal));

        dataRaw = new Instances("TrainingInstances", atts, 10);
        // Set once up front so the dataset has a class attribute even before any training
        // instance is added.
        dataRaw.setClassIndex(CLASS_INDEX);
    }

    /**
     * Appends one labelled training example to the dataset.
     *
     * @param fiveWayClazz sentiment label of the example
     * @param words tokens of the example; they are joined with single spaces
     * @throws IllegalArgumentException if {@code fiveWayClazz} is not one of the five labels
     *     registered on the class attribute
     */
    public void addTrainingInstance(SentimentClass.FiveWayClazz fiveWayClazz, String[] words) {
        // Resolve the label by name rather than ordinal(): the nominal attribute's label order is
        // defined above, independently of the enum's declaration order.
        int labelIndex = dataRaw.attribute(CLASS_INDEX).indexOfValue(fiveWayClazz.name());
        if (labelIndex < 0) {
            throw new IllegalArgumentException("Unsupported sentiment class: " + fiveWayClazz);
        }

        double[] instanceValue = new double[dataRaw.numAttributes()];
        instanceValue[CONTENT_INDEX] =
                dataRaw.attribute(CONTENT_INDEX).addStringValue(joinWords(words));
        instanceValue[CLASS_INDEX] = labelIndex;
        dataRaw.add(new DenseInstance(1.0, instanceValue));
    }

    /**
     * Builds the classifier from all training instances added so far.
     *
     * @throws Exception if Weka fails to build the classifier
     */
    public void trainModel() throws Exception {
        classifier.buildClassifier(dataRaw);
    }

    /**
     * Evaluates the classifier on the training dataset itself and prints Weka's summary to
     * standard output. Because the same data is used for training and testing, the reported
     * figures are resubstitution estimates, not held-out performance.
     *
     * @throws Exception if Weka fails during evaluation
     */
    public void testModel() throws Exception {
        Evaluation eTest = new Evaluation(dataRaw);
        eTest.evaluateModel(classifier, dataRaw);
        System.out.println(eTest.toSummaryString());
    }

    /** Prints the current training dataset to standard output. */
    public void showInstances() {
        System.out.println(dataRaw);
    }

    /**
     * Returns the live training dataset (not a copy).
     *
     * @return the dataset backing this trainer
     */
    public Instances getDataRaw() {
        return dataRaw;
    }

    /**
     * Serializes the classifier to the model path given at construction.
     *
     * @throws Exception if writing the model fails
     */
    public void saveModel() throws Exception {
        weka.core.SerializationHelper.write(modelFile, classifier);
    }

    /**
     * Replaces the current classifier with one read from {@code _modelFile}.
     *
     * <p>Weka's {@code SerializationHelper.read} uses Java object deserialization, so only load
     * model files from trusted sources.
     *
     * @param _modelFile path of a serialized {@link NaiveBayesMultinomialText} model
     * @throws Exception if reading or deserializing the model fails
     */
    public void loadModel(String _modelFile) throws Exception {
        this.classifier = (NaiveBayesMultinomialText) weka.core.SerializationHelper.read(_modelFile);
    }

    /**
     * Predicts the sentiment of a sentence with the current classifier.
     *
     * @param sentence raw text to classify
     * @return the predicted sentiment label
     * @throws Exception if Weka fails to classify the instance
     */
    public SentimentClass.FiveWayClazz classify(String sentence) throws Exception {
        double[] instanceValue = new double[dataRaw.numAttributes()];
        // NOTE(review): this registers the sentence on the shared header's string attribute,
        // so repeated classification presumably grows it — confirm whether a separate header
        // is warranted for long-running use.
        instanceValue[CONTENT_INDEX] = dataRaw.attribute(CONTENT_INDEX).addStringValue(sentence);

        Instance toClassify = new DenseInstance(1.0, instanceValue);
        dataRaw.setClassIndex(CLASS_INDEX);
        toClassify.setDataset(dataRaw);
        // The true label is unknown: mark it missing instead of leaving index 0 (NEGATIVE).
        toClassify.setClassMissing();

        double prediction = classifier.classifyInstance(toClassify);
        // Map back through the label name so the result does not depend on the enum's
        // declaration order matching the attribute's label order.
        String label = dataRaw.classAttribute().value((int) prediction);
        return SentimentClass.FiveWayClazz.valueOf(label);
    }

    /** Joins tokens with single spaces; null elements are rendered as {@code "null"}. */
    private static String joinWords(String[] words) {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < words.length; i++) {
            if (i > 0) {
                joined.append(' ');
            }
            joined.append(words[i]);
        }
        return joined.toString();
    }
}