package uk.ac.soton.ecs.comp6237.l14;

import java.awt.BasicStroke;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JPanel;
import javax.swing.JTextField;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.DefaultXYDataset;
import org.openimaj.content.slideshow.Slide;

/**
 * Base class for gradient descent demos showing how to fit a line to data
 *
 * @author Jonathon Hare ([email protected])
 *
 */
public abstract class AbstractGradientDescentDemo implements Slide, Runnable {

	protected static class ImageContainer extends Component {
		private static final long serialVersionUID = 1L;
		BufferedImage image;

		public ImageContainer(BufferedImage img) {
			this.image = img;
			this.setPreferredSize(new Dimension(image.getWidth(), image.getHeight()));
		}

		@Override
		public void paint(Graphics g) {
			super.paint(g);
			g.drawImage(image, 0, 0, null);
		}

		public void update(BufferedImage img) {
			image = img;
			this.repaint();
		}
	}

	protected double[] params = { 0, 0 };
	int iter = 0;
	private JFreeChart chart;
	private ImageContainer chartContainer;
	private DefaultXYDataset chartDataset;
	private JTextField paramsField;
	private DefaultXYDataset errorDataset;
	protected double[][] X;
	private JFreeChart errorChart;
	private ImageContainer errorContainer;
	private double[][] errorSeries;
	protected double alpha = 0.01;
	private int maxIter = 500;

	public AbstractGradientDescentDemo() {
		super();
	}

	@SuppressWarnings("deprecation")
	@Override
	public Component getComponent(int width, int height) throws IOException {
		final JPanel base = new JPanel();
		base.setOpaque(false);
		base.setPreferredSize(new Dimension(width, height));
		base.setLayout(new BoxLayout(base, BoxLayout.Y_AXIS));

		chartDataset = new DefaultXYDataset();
		X = createData();
		chartDataset.addSeries("points", X);
		final double[][] lineData = computeLineData();
		chartDataset.addSeries("line", lineData);

		chart = ChartFactory.createXYLineChart(null, "x", "y", chartDataset, PlotOrientation.VERTICAL,
				false, false, false);
		((XYLineAndShapeRenderer) chart.getXYPlot().getRenderer()).setSeriesLinesVisible(0, false);
		((XYLineAndShapeRenderer) chart.getXYPlot().getRenderer()).setSeriesShapesVisible(0, true);
		((NumberAxis) chart.getXYPlot().getDomainAxis()).setRange(-5, 5);
		((NumberAxis) chart.getXYPlot().getRangeAxis()).setRange(-10, 10);

		((XYLineAndShapeRenderer) chart.getXYPlot().getRenderer()).setStroke(new BasicStroke(2.5f));

		chartContainer = new ImageContainer(chart.createBufferedImage(width, height / 2));
		base.add(chartContainer);

		final JPanel bottomPane = new JPanel();
		bottomPane.setPreferredSize(new Dimension(width, height / 2));
		base.add(bottomPane);

		final JPanel controlsdata = new JPanel();
		controlsdata.setLayout(new BoxLayout(controlsdata, BoxLayout.X_AXIS));
		bottomPane.add(controlsdata);
		final JButton button = new JButton("Go");
		controlsdata.add(button);

		button.addActionListener(new ActionListener() {
			@Override
			public void actionPerformed(ActionEvent e) {
				button.setEnabled(false);
				base.requestFocus();
				new Thread(AbstractGradientDescentDemo.this).start();
			}
		});

		paramsField = new JTextField(20);
		paramsField.setOpaque(false);
		paramsField.setFont(Font.decode("Monaco-24"));
		paramsField.setHorizontalAlignment(JTextField.CENTER);
		paramsField.setEditable(false);
		paramsField.setBorder(null);
		paramsField.setText(String.format("%2.2f, %2.2f", params[0], params[1]));
		controlsdata.add(paramsField);

		errorDataset = new DefaultXYDataset();
		errorSeries = new double[][] { { 0 }, { computeError() } };
		errorDataset.addSeries("data", errorSeries);
		errorChart = ChartFactory.createXYLineChart("Error over time",
				"Iteration", "Error", errorDataset,
				PlotOrientation.VERTICAL,
				false, false, false);
		((NumberAxis) errorChart.getXYPlot().getDomainAxis()).setRange(0, 1);
		((NumberAxis) errorChart.getXYPlot().getRangeAxis()).setRange(0, computeError());
		errorContainer = new ImageContainer(errorChart.createBufferedImage((width - 5) / 2, (height - 5) / 2));
		bottomPane.add(errorContainer);

		return base;
	}

	private double computeError() {
		double cost = 0;
		for (int i = 0; i < X[0].length; i++) {
			final double e = error(new double[] { X[0][i], X[1][i] }, params);
			cost += e * e;
		}
		return cost / (2 * X.length);

	}

	protected double error(double[] x, double[] params) {
		return ((x[0] * params[0] + params[1]) - x[1]);
	}

	private double[][] computeLineData() {
		return new double[][] { { -10, 10 }, { -10 * params[0] + params[1], 10 * params[0] + params[1] } };
	}

	private double[][] createData() {
		final Random rng = new Random(0);
		final double[][] data = new double[2][1000];
		for (int i = 0; i < data[0].length; i++) {
			data[0][i] = (rng.nextDouble() - 0.5) * 10;
			data[1][i] = data[0][i] + rng.nextGaussian() + 3;
		}
		return data;
	}

	@Override
	public void close() {
		params = new double[] { 0, 0 };
		iter = 0;
	}

	@Override
	public void run() {
		while (iter < maxIter) {
			iter++;

			performIteration();

			updateDisplay();
		}
	}

	protected void updateDisplay() {
		final double[][] tmp = new double[][] { Arrays.copyOf(errorSeries[0], iter + 1),
				Arrays.copyOf(errorSeries[1], iter + 1) };
		tmp[0][iter] = iter;
		tmp[1][iter] = this.computeError();
		this.errorSeries = tmp;

		chartDataset.removeSeries("line");
		chartDataset.addSeries("line", computeLineData());
		chartContainer.update(chart.createBufferedImage(chartContainer.image.getWidth(),
				chartContainer.image.getHeight()));

		paramsField.setText(String.format("%2.2f, %2.2f", params[0], params[1]));

		errorDataset.removeSeries("data");
		errorDataset.addSeries("data", errorSeries);
		((NumberAxis) errorChart.getXYPlot().getDomainAxis()).setRange(0, iter);
		errorContainer.update(errorChart.createBufferedImage(errorContainer.image.getWidth(),
				errorContainer.image.getHeight()));
	}

	/**
	 * Perform a single iteration (epoch) of gradient descent. Superclasses should
	 * override.
	 */
	protected void performIteration() {

	}

	protected double[] errorv(double[][] X, double[] params) {
		final double[] ev = new double[X[0].length];
		for (int i = 0; i < X[0].length; i++) {
			ev[i] = ((X[0][i] * params[0] + params[1]) - X[1][i]);
		}
		return ev;
	}

}