package uk.ac.soton.ecs.comp6237.l5;

import java.awt.Component;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JSeparator;
import javax.swing.SwingConstants;
import javax.swing.SwingUtilities;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.labels.ItemLabelAnchor;
import org.jfree.chart.labels.ItemLabelPosition;
import org.jfree.chart.labels.StandardXYItemLabelGenerator;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.AbstractXYDataset;
import org.jfree.data.xy.XYDataset;
import org.jfree.ui.TextAnchor;
import org.openimaj.content.slideshow.Slide;
import org.openimaj.content.slideshow.SlideshowApplication;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.FloatFVComparator;
import org.openimaj.feature.FloatFVComparison;
import org.openimaj.math.geometry.line.Line2d;
import org.openimaj.math.geometry.point.Point2dImpl;

import uk.ac.soton.ecs.comp6237.l3.ItemTermData;
import uk.ac.soton.ecs.comp6237.utils.Utils;
import uk.ac.soton.ecs.comp6237.utils.annotations.Demonstration;

/**
 * Interactive slide demonstrating multidimensional scaling (MDS) of the items in
 * {@code moduledata.txt}. Each item starts at a random 2D position. Gradient descent then
 * moves the points to reduce the relative error between their 2D distances and the
 * distances between the items' term-count vectors under the selected measure.
 * <p>
 * The optimisation runs on a background thread. All Swing/JFreeChart updates are
 * marshalled onto the event dispatch thread.
 *
 * @author Jonathon Hare ([email protected])
 */
@Demonstration(title = "Multidimensional Scaling Demo (Sammon Mapping)")
public class MDSDemo implements Slide, ActionListener {
	/** Upper bound on the number of gradient-descent iterations per run. */
	private static final int MAX_ITER = 50000;
	/** Step size at iteration 0; it decays exponentially to INIT_LEARNING_RATE / e at MAX_ITER. */
	private static final double INIT_LEARNING_RATE = 0.005;
	private JButton runBtn;
	private JButton cnclBtn;
	/** Set true when a run starts; cleared by Cancel, close() or the worker finishing. */
	private volatile boolean isRunning;
	private FloatFVComparator distanceMeasure = null;
	private JComboBox<String> distCombo;
	private ItemTermData data = new ItemTermData("moduledata.txt");
	/**
	 * Current 2D layout, one point per item. The list is replaced wholesale (never cleared in
	 * place), so the chart reading it on the EDT never sees a partially-populated list.
	 */
	private volatile List<Point2dImpl> points = new ArrayList<Point2dImpl>();
	/** Pairwise distances between items' term-count vectors (the targets). */
	private double[][] distances = new double[data.getItemNames().size()][data.getItemNames().size()];
	/** Pairwise Euclidean distances between items in the current 2D layout. */
	private double[][] fakeDistances = new double[data.getItemNames().size()][data.getItemNames().size()];

	/**
	 * Single-series XY view over {@link #points}, labelled with the item names.
	 */
	class Dataset extends AbstractXYDataset {
		private static final long serialVersionUID = 1L;

		@Override
		public Number getY(int series, int item) {
			return points.get(item).y;
		}

		@Override
		public Number getX(int series, int item) {
			return points.get(item).x;
		}

		/** Returns the item name used as the label of the given point. */
		public String getLabel(int series, int item) {
			return data.getItemNames().get(item);
		}

		@Override
		public int getItemCount(int arg0) {
			return data.getItemNames().size();
		}

		@Override
		public int getSeriesCount() {
			return 1;
		}

		@Override
		public Comparable<String> getSeriesKey(int arg0) {
			return "DATA";
		}
	}

	Dataset dataset = new Dataset();
	private JFreeChart chart;
	private ChartPanel chartPanel;
	private JLabel iterLabel;

	/** Creates the demo with one random initial point per item. */
	public MDSDemo() {
		for (int i = 0; i < dataset.getItemCount(0); i++)
			this.points.add((Point2dImpl) Point2dImpl.createRandomPoint());
	}

	@Override
	public Component getComponent(int width, int height) throws IOException {
		final JPanel base = new JPanel();
		base.setOpaque(false);
		base.setPreferredSize(new Dimension(width, height));
		base.setLayout(new BoxLayout(base, BoxLayout.Y_AXIS));

		chart = ChartFactory.createScatterPlot("", "", "", dataset, PlotOrientation.VERTICAL, false, false, false);

		// shapes only (no connecting lines); item labels placed outside the point at 6 o'clock
		final XYItemRenderer renderer = new XYLineAndShapeRenderer(false, true) {
			private static final long serialVersionUID = 1L;

			@Override
			public ItemLabelPosition getPositiveItemLabelPosition(int row, int column) {
				return new ItemLabelPosition(ItemLabelAnchor.OUTSIDE6, TextAnchor.TOP_LEFT);
			}
		};
		final Font font = Font.decode("Helvetica Neue-22");
		renderer.setBaseItemLabelFont(font);
		chart.getXYPlot().setRenderer(renderer);
		chart.getXYPlot().getDomainAxis().setTickLabelFont(font);
		chart.getXYPlot().getRangeAxis().setTickLabelFont(font);

		chart.getXYPlot().getRenderer().setBaseItemLabelGenerator(new StandardXYItemLabelGenerator() {
			private static final long serialVersionUID = 1L;

			@Override
			public String generateLabel(XYDataset ds, int series, int item) {
				return ((Dataset) ds).getLabel(series, item);
			}
		});
		chart.getXYPlot().getRenderer().setBaseItemLabelsVisible(true);

		chartPanel = new ChartPanel(chart);
		chart.setBackgroundPaint(new java.awt.Color(255, 255, 255, 255));
		chart.getXYPlot().setBackgroundPaint(java.awt.Color.WHITE);
		chart.getXYPlot().setRangeGridlinePaint(java.awt.Color.GRAY);
		chart.getXYPlot().setDomainGridlinePaint(java.awt.Color.GRAY);

		chartPanel.setSize(width, height - 50);
		chartPanel.setPreferredSize(chartPanel.getSize());
		base.add(chartPanel);

		final JPanel controls = new JPanel();
		controls.setPreferredSize(new Dimension(width, 50));
		controls.setMaximumSize(new Dimension(width, 50));
		controls.setSize(new Dimension(width, 50));

		controls.add(new JSeparator(SwingConstants.VERTICAL));
		controls.add(new JLabel("Distance:"));

		distCombo = new JComboBox<String>();
		distCombo.addItem("Euclidean");
		distCombo.addItem("1-Pearson");
		distCombo.addItem("1-Cosine");
		controls.add(distCombo);

		controls.add(new JSeparator(SwingConstants.VERTICAL));

		runBtn = new JButton("Run MDS");
		runBtn.setActionCommand("button.run");
		runBtn.addActionListener(this);
		controls.add(runBtn);

		controls.add(new JSeparator(SwingConstants.VERTICAL));

		cnclBtn = new JButton("Cancel");
		cnclBtn.setEnabled(false);
		cnclBtn.setActionCommand("button.cancel");
		cnclBtn.addActionListener(this);
		controls.add(cnclBtn);

		base.add(controls);

		controls.add(new JSeparator(SwingConstants.VERTICAL));
		// blank placeholder text fixes the label's width before any status is shown
		iterLabel = new JLabel("                         ");
		final Dimension size = iterLabel.getPreferredSize();
		iterLabel.setMinimumSize(size);
		iterLabel.setPreferredSize(size);
		controls.add(iterLabel);

		updateImage();

		return base;
	}

	/**
	 * Maps a distance-combo entry to its comparator.
	 *
	 * @param name
	 *            the selected combo-box entry
	 * @return the comparator for that entry
	 * @throws IllegalStateException
	 *             if the entry is not one of the supported measures
	 */
	private static FloatFVComparator createDistanceMeasure(Object name) {
		if ("Euclidean".equals(name))
			return FloatFVComparison.EUCLIDEAN;
		if ("1-Cosine".equals(name))
			return FloatFVComparison.COSINE_DIST;
		if ("1-Pearson".equals(name)) {
			return new FloatFVComparator() {

				@Override
				public double compare(FloatFV o1, FloatFV o2) {
					return 1 - FloatFVComparison.CORRELATION.compare(o1, o2);
				}

				@Override
				public boolean isDistance() {
					return true;
				}

				@Override
				public double compare(float[] h1, float[] h2) {
					return 1 - FloatFVComparison.CORRELATION.compare(h1, h2);
				}
			};
		}
		throw new IllegalStateException("Unknown distance measure: " + name);
	}

	/**
	 * Prepares a new run. It computes the target distance matrix under the chosen measure and
	 * publishes a fresh random layout.
	 *
	 * @param measureName
	 *            the entry selected in the distance combo box (read on the EDT by the caller)
	 */
	private void initMDS(Object measureName) {
		this.distanceMeasure = createDistanceMeasure(measureName);

		final float[][] counts = data.getCounts();
		final List<Point2dImpl> newPoints = new ArrayList<Point2dImpl>(distances.length);
		for (int i = 0; i < distances.length; i++) {
			newPoints.add((Point2dImpl) Point2dImpl.createRandomPoint());

			for (int j = i + 1; j < distances.length; j++) {
				double d = distanceMeasure.compare(counts[i], counts[j]);
				// the relative error term divides by the target distance, so never let it be 0
				if (d == 0)
					d = 0.001;
				distances[i][j] = d;
				distances[j][i] = d;
			}
		}
		// publish the complete list in one step instead of clearing the live one
		this.points = newPoints;

		updateImage();
	}

	/** Runs {@code task} on the Swing event dispatch thread, immediately if already on it. */
	private static void runOnEdt(Runnable task) {
		if (SwingUtilities.isEventDispatchThread())
			task.run();
		else
			SwingUtilities.invokeLater(task);
	}

	/** Asks the chart to redraw the current layout; safe to call from any thread. */
	private void updateImage() {
		runOnEdt(new Runnable() {
			@Override
			public void run() {
				// re-setting the same dataset fires a change event, which redraws the chart
				chart.getXYPlot().setDataset(chart.getXYPlot().getDataset());
			}
		});
	}

	/**
	 * Performs one gradient-descent step on the 2D layout.
	 *
	 * @param lasterror
	 *            total error from the previous step
	 * @param iter
	 *            1-based iteration number (controls the learning-rate decay)
	 * @return the total error of the layout before this step's update. If it is not lower than
	 *         {@code lasterror}, the layout is left unchanged.
	 */
	double performStep(double lasterror, int iter) {
		final List<Point2dImpl> pts = points;
		final int n = distances.length;

		for (int i = 0; i < n; i++) {
			for (int j = i + 1; j < n; j++) {
				final double d = Line2d.distance(pts.get(i), pts.get(j));
				fakeDistances[i][j] = d;
				fakeDistances[j][i] = d;
			}
		}

		final Point2dImpl[] grad = new Point2dImpl[n];
		for (int i = 0; i < n; i++)
			grad[i] = new Point2dImpl();

		double totalError = 0;
		for (int k = 0; k < n; k++) {
			final Point2dImpl pk = pts.get(k);
			for (int j = k + 1; j < n; j++) {
				final Point2dImpl pj = pts.get(j);
				final double fake = fakeDistances[j][k];
				final double errorterm = (fake - distances[j][k]) / distances[j][k];
				totalError += Math.abs(errorterm);

				// coincident points have no defined direction; skip to avoid a 0/0 NaN
				if (fake == 0)
					continue;

				// the pair term is symmetric: k moves along (pk - pj) and j equally along
				// (pj - pk), so both ends of the pair must receive it
				final double gx = ((pk.x - pj.x) / fake) * errorterm;
				final double gy = ((pk.y - pj.y) / fake) * errorterm;
				grad[k].x += gx;
				grad[k].y += gy;
				grad[j].x -= gx;
				grad[j].y -= gy;
			}
		}

		if (totalError >= lasterror)
			return totalError;

		final float rate = getLearningRate(iter);
		for (int k = 0; k < n; k++) {
			pts.get(k).x -= rate * grad[k].x;
			pts.get(k).y -= rate * grad[k].y;
		}

		return totalError;
	}

	/** Exponentially decaying step size: INIT_LEARNING_RATE * exp(-iter / MAX_ITER). */
	private float getLearningRate(int iter) {
		// floating-point division; integer division here would truncate to 0 (no decay)
		return (float) (INIT_LEARNING_RATE * Math.exp(-(double) iter / MAX_ITER));
	}

	@Override
	public void close() {
		isRunning = false;
	}

	/**
	 * Worker-thread body. It initialises a run, iterates until the error stops falling, the
	 * iteration limit is hit or the run is cancelled, and then restores the controls.
	 */
	private void runMDS(Object measureName) {
		try {
			if (isRunning) {
				initMDS(measureName);
				try {
					// pause briefly after showing the random starting layout
					Thread.sleep(500);
				} catch (final InterruptedException ex) {
					Thread.currentThread().interrupt();
				}
			}

			int iter = 0;
			double lasterror = Double.MAX_VALUE;
			while (isRunning && iter++ < MAX_ITER) {
				final double thiserror = performStep(lasterror, iter);
				final String status = String.format("%4.2f %5d", thiserror, iter);
				runOnEdt(new Runnable() {
					@Override
					public void run() {
						iterLabel.setText(status);
					}
				});
				updateImage();

				if (thiserror >= lasterror)
					break;

				lasterror = thiserror;
			}
			updateImage();
		} finally {
			// clear the flag before re-enabling Run, so a quick re-click can't be cancelled by us
			isRunning = false;
			runOnEdt(new Runnable() {
				@Override
				public void run() {
					runBtn.setEnabled(true);
					cnclBtn.setEnabled(false);
				}
			});
		}
	}

	@Override
	public void actionPerformed(ActionEvent e) {
		final String command = e.getActionCommand();
		if (command.equals("button.clear")) {
			// no control in this class currently issues this command
			updateImage();
		} else if (command.equals("button.run")) {
			runBtn.setEnabled(false);
			cnclBtn.setEnabled(true);
			isRunning = true;

			// read Swing state here on the EDT rather than from the worker thread
			final Object selectedMeasure = distCombo.getSelectedItem();

			new Thread(new Runnable() {
				@Override
				public void run() {
					runMDS(selectedMeasure);
				}
			}).start();
		} else if (command.equals("button.cancel")) {
			isRunning = false;
			cnclBtn.setEnabled(false);
		}
	}

	public static void main(String[] args) throws IOException {
		new SlideshowApplication(new MDSDemo(), 1024, 768, Utils.BACKGROUND_IMAGE);
	}
}