package org.neuroph.adapters.weka;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.MultiLayerPerceptron;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;

/**
 * Example usage of Neuroph Weka adapters.
 *
 * <p>The sample loads a Weka data set from an ARFF file, normalizes it,
 * converts it to a Neuroph data set and back again, prints all three data
 * sets for comparison, and finally exercises {@code WekaNeurophClassifier}
 * on a single test item.
 *
 * @author Zoran Sevarac
 */
public class WekaNeurophSample {

    /** Location of the ARFF file loaded by this sample. */
    private static final String DATASET_PATH = "datasets/iris.arff";

    /** Number of input attributes passed to the converter and the network. */
    private static final int INPUT_COUNT = 4;

    /** Number of neurons in the hidden layer of the sample network. */
    private static final int HIDDEN_COUNT = 16;

    /** Number of outputs (one per class) passed to the converter and the network. */
    private static final int OUTPUT_COUNT = 3;

    /** Index of the class attribute; it directly follows the input attributes. */
    private static final int CLASS_INDEX = INPUT_COUNT;

    /**
     * Runs the sample: load, normalize, convert in both directions, print,
     * then test the classifier.
     *
     * @param args command line arguments (not used)
     * @throws Exception propagated from the Weka loading and filtering calls
     */
    public static void main(String[] args) throws Exception {

        // create weka dataset from file
        DataSource dataSource = new DataSource(DATASET_PATH);
        Instances wekaDataset = dataSource.getDataSet();
        wekaDataset.setClassIndex(CLASS_INDEX);

        // normalize dataset
        Normalize filter = new Normalize();
        filter.setInputFormat(wekaDataset);
        wekaDataset = Filter.useFilter(wekaDataset, filter);

        // convert weka dataset to neuroph dataset
        DataSet neurophDataset =
                WekaDataSetConverter.convertWekaToNeurophDataset(wekaDataset, INPUT_COUNT, OUTPUT_COUNT);

        // convert back neuroph dataset to weka dataset
        Instances testWekaDataset = WekaDataSetConverter.convertNeurophToWekaDataset(neurophDataset);

        // print out all to compare
        System.out.println("Weka data set from file");
        printDataSet(wekaDataset);

        System.out.println("Neuroph data set converted from Weka data set");
        printDataSet(neurophDataset);

        System.out.println("Weka data set reconverted from Neuroph data set");
        printDataSet(testWekaDataset);

        System.out.println("Testing WekaNeurophClassifier");
        testNeurophWekaClassifier(wekaDataset);
    }

    /**
     * Prints a Neuroph data set: the input vector of every row, followed by
     * the desired output vector when the row has one.
     *
     * @param neurophDataset Neuroph data set to print
     */
    public static void printDataSet(DataSet neurophDataset) {
        System.out.println("Neuroph dataset");
        Iterator<DataSetRow> rows = neurophDataset.iterator();

        while (rows.hasNext()) {
            DataSetRow row = rows.next();
            System.out.println("inputs");
            System.out.println(Arrays.toString(row.getInput()));
            // rows with an empty desired output print inputs only
            if (row.getDesiredOutput().length > 0) {
                System.out.println("outputs");
                System.out.println(Arrays.toString(row.getDesiredOutput()));
            }
        }
    }

    /**
     * Prints a Weka data set: the numeric values of every instance, followed
     * by the string value of its class attribute.
     *
     * <p>The class value line is skipped when the data set has no class index
     * set, since there is no class attribute to look up in that case.
     *
     * @param wekaDataset Weka data set to print
     */
    private static void printDataSet(Instances wekaDataset) {
        System.out.println("Weka dataset");
        // instances held by a data set share its class index, so read it once
        final int classIndex = wekaDataset.classIndex();
        Enumeration<Instance> instances = wekaDataset.enumerateInstances();
        while (instances.hasMoreElements()) {
            Instance instance = instances.nextElement();
            System.out.println(Arrays.toString(instance.toDoubleArray()));
            if (classIndex >= 0) {
                System.out.println(instance.stringValue(classIndex));
            }
        }
    }

    /**
     * Tests WekaNeurophClassifier: builds it on the given data set using a
     * multi layer perceptron, then prints the predicted class index and the
     * class distribution for one hard-coded test item.
     *
     * <p>Any exception raised while building or using the classifier is
     * logged and not rethrown.
     *
     * @param wekaDataset Weka data set used to build the classifier; its
     *                    class index is set by this method
     */
    private static void testNeurophWekaClassifier(Instances wekaDataset) {
        try {
            MultiLayerPerceptron neuralNet =
                    new MultiLayerPerceptron(INPUT_COUNT, HIDDEN_COUNT, OUTPUT_COUNT);

            // set labels manually
            neuralNet.getOutputNeurons().get(0).setLabel("Setosa");
            neuralNet.getOutputNeurons().get(1).setLabel("Versicolor");
            neuralNet.getOutputNeurons().get(2).setLabel("Virginica");

            // initialize NeurophWekaClassifier
            WekaNeurophClassifier neurophWekaClassifier = new WekaNeurophClassifier(neuralNet);
            // set class index on data set
            wekaDataset.setClassIndex(CLASS_INDEX);

            // process data set
            neurophWekaClassifier.buildClassifier(wekaDataset);

            // test item: normalized form of {5.1, 3.5, 1.4, 0.2, 0.0};
            // the last value sits at the class attribute position
            double[] item = {0.22222222222222213, 0.6249999999999999, 0.06779661016949151, 0.04166666666666667, 0};

            // create weka instance (weight 1) for test item
            Instance instance = new DenseInstance(1, item);

            // test classification
            System.out.println("NeurophWekaClassifier - classifyInstance for {5.1, 3.5, 1.4, 0.2}");
            System.out.println("Class idx: " + neurophWekaClassifier.classifyInstance(instance));
            System.out.println("NeurophWekaClassifier - distributionForInstance for {5.1, 3.5, 1.4, 0.2}");
            double[] dist = neurophWekaClassifier.distributionForInstance(instance);
            for (int i = 0; i < dist.length; i++) {
                System.out.println("Class " + i + ": " + dist[i]);
            }

        } catch (Exception ex) {
            Logger.getLogger(WekaNeurophSample.class.getName())
                    .log(Level.SEVERE, "WekaNeurophClassifier test failed", ex);
        }

    }
}