package org.neuroph.contrib.samples.timeseries;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;

/**
 *
 * @author zoran
 */
public class TestTimeSeries implements LearningEventListener {
    NeuralNetwork<?> neuralNet;
    DataSet trainingSet;
    
  public static void main(String[] args) {                
      TestTimeSeries tts = new TestTimeSeries();
      tts.train();
      tts.testNeuralNetwork();     
    }
  
    public void train() {
        // get the path to file with data
        String inputFileName = "C:\\timeseries\\BSW15";
        
        // create MultiLayerPerceptron neural network
        neuralNet = new MultiLayerPerceptron(TransferFunctionType.TANH, 5, 10, 1);
        MomentumBackpropagation learningRule = (MomentumBackpropagation)neuralNet.getLearningRule();
        learningRule.setLearningRate(0.2);
        learningRule.setMomentum(0.5);
        // learningRule.addObserver(this);
        learningRule.addListener(this);        
        
        // create training set from file
         trainingSet = DataSet.createFromFile(inputFileName, 5, 1, "\t", false);
        // train the network with training set
        neuralNet.learn(trainingSet);         
              
        System.out.println("Done training.");          
    }
  
    
    /**
     * Prints network output for the each element from the specified training set.
     * @param neuralNet neural network
     * @param trainingSet training set
     */
    public void testNeuralNetwork() {
        System.out.println("Testing network...");
        for(DataSetRow trainingElement : trainingSet.getRows()) {
            neuralNet.setInput(trainingElement.getInput());
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();

            System.out.print("Input: " + Arrays.toString( trainingElement.getInput() ) );
            System.out.println(" Output: " + Arrays.toString( networkOutput) );
        }
    }

//	@Override
//	public void update(Observable arg0, Object arg1) {
//		SupervisedLearning rule = (SupervisedLearning)arg0;
//		System.out.println( "Training, Network Epoch " + rule.getCurrentIteration() + ", Error:" + rule.getTotalNetworkError());
//	}

    public void handleLearningEvent(LearningEvent event) {
		SupervisedLearning rule = (SupervisedLearning)event.getSource();
		System.out.println( "Training, Network Epoch " + rule.getCurrentIteration() + ", Error:" + rule.getTotalNetworkError());
    }
    
}