package uk.ac.soton.ecs.comp6237.l3; import java.awt.Component; import java.awt.Dimension; import java.awt.Font; import java.awt.Point; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.MouseAdapter; import java.awt.event.MouseEvent; import java.awt.image.BufferedImage; import java.io.IOException; import java.util.ArrayList; import java.util.List; import javax.swing.BoxLayout; import javax.swing.JButton; import javax.swing.JComboBox; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JSeparator; import javax.swing.SwingConstants; import org.openimaj.content.slideshow.Slide; import org.openimaj.content.slideshow.SlideshowApplication; import org.openimaj.feature.FloatFVComparator; import org.openimaj.feature.FloatFVComparison; import org.openimaj.image.DisplayUtilities; import org.openimaj.image.DisplayUtilities.ImageComponent; import org.openimaj.image.ImageUtilities; import org.openimaj.image.MBFImage; import org.openimaj.image.colour.ColourSpace; import org.openimaj.image.colour.RGBColour; import org.openimaj.image.renderer.MBFImageRenderer; import org.openimaj.image.renderer.RenderHints; import org.openimaj.image.typography.FontStyle.HorizontalAlignment; import org.openimaj.image.typography.FontStyle.VerticalAlignment; import org.openimaj.image.typography.general.GeneralFont; import org.openimaj.image.typography.general.GeneralFontStyle; import org.openimaj.math.geometry.point.Point2d; import org.openimaj.math.geometry.point.Point2dImpl; import org.openimaj.math.geometry.shape.Circle; import org.openimaj.math.geometry.shape.Rectangle; import uk.ac.soton.ecs.comp6237.utils.Utils; import uk.ac.soton.ecs.comp6237.utils.annotations.Demonstration; /** * Demo showing Hierarchical Agglomerative Clustering with WPGMC * * @author Jonathon Hare ([email protected]) */ @Demonstration(title = "Hierarchical Agglomerative Clustering Demo") public class HClusterDemo extends MouseAdapter implements Slide, ActionListener { enum Linkage { WPGMC { @Override public double computeDistance(BiCluster a, BiCluster b, FloatFVComparator distanceMeasure) { return distanceMeasure.compare(a.vec, b.vec); } @Override public float[] computeVec(BiCluster left, BiCluster right) { final float[] mergevec = new float[2]; for (int i = 0; i < mergevec.length; i++) mergevec[i] = (left.vec[i] + right.vec[i]) / 2; return mergevec; } }, UPGMC { @Override public double computeDistance(BiCluster a, BiCluster b, FloatFVComparator distanceMeasure) { return distanceMeasure.compare(a.vec, b.vec); } @Override public float[] computeVec(BiCluster left, BiCluster right) { final List<float[]> leaves = new ArrayList<>(); leaves.addAll(left.getAllLeaves()); leaves.addAll(right.getAllLeaves()); final float[] mean = new float[2]; for (final float[] v : leaves) { mean[0] += v[0]; mean[1] += v[1]; } mean[0] /= leaves.size(); mean[1] /= leaves.size(); return mean; } }, Minimum { @Override public double computeDistance(BiCluster a, BiCluster b, FloatFVComparator distanceMeasure) { final List<float[]> leavesA = a.getAllLeaves(); final List<float[]> leavesB = b.getAllLeaves(); double min = Double.MAX_VALUE; for (final float[] la : leavesA) { for (final float[] lb : leavesB) { final double d = distanceMeasure.compare(la, lb); if (d < min) min = d; } } return min; } }, Maximum { @Override public double computeDistance(BiCluster a, BiCluster b, FloatFVComparator distanceMeasure) { final List<float[]> leavesA = a.getAllLeaves(); final List<float[]> leavesB = b.getAllLeaves(); double max = -Double.MAX_VALUE; for (final float[] la : leavesA) { for (final float[] lb : leavesB) { final double d = distanceMeasure.compare(la, lb); if (d > max) max = d; } } return max; } }, UPGMA { @Override public double computeDistance(BiCluster a, BiCluster b, FloatFVComparator distanceMeasure) { final List<float[]> leavesA = a.getAllLeaves(); final List<float[]> leavesB = b.getAllLeaves(); double sum = 0; for (final float[] la : leavesA) { for (final float[] lb : leavesB) { sum += distanceMeasure.compare(la, lb); } } return sum / (leavesA.size() * leavesB.size()); } }, ; public float[] computeVec(BiCluster left, BiCluster right) { return null; } public abstract double computeDistance(BiCluster a, BiCluster b, FloatFVComparator distanceMeasure); } class BiCluster { BiCluster left; BiCluster right; float[] vec; double distance; char label; List<float[]> getAllLeaves() { final List<float[]> leaves = new ArrayList<>(); getAllLeaves(leaves); return leaves; } private void getAllLeaves(List<float[]> leaves) { if (left != null) { left.getAllLeaves(leaves); right.getAllLeaves(leaves); } else { leaves.add(vec); } } } private MBFImage image; private ImageComponent ic; private BufferedImage bimg; private List<Point2d> points = new ArrayList<Point2d>(); private JButton runBtn; private JButton clearBtn; private JButton cnclBtn; private volatile boolean isRunning; private MBFImageRenderer renderer; private FloatFVComparator distanceMeasure = FloatFVComparison.EUCLIDEAN; private JComboBox<String> distCombo; List<BiCluster> clusters = new ArrayList<BiCluster>(); Linkage linkage = Linkage.WPGMC; private JComboBox<String> linkCombo; @Override public Component getComponent(int width, int height) throws IOException { final JPanel base = new JPanel(); base.setOpaque(false); base.setPreferredSize(new Dimension(width, height)); base.setLayout(new BoxLayout(base, BoxLayout.Y_AXIS)); image = new MBFImage(width, height - 50, ColourSpace.RGB); renderer = image.createRenderer(RenderHints.ANTI_ALIASED); resetImage(); ic = new DisplayUtilities.ImageComponent(true, false); ic.setShowPixelColours(false); ic.setShowXYPosition(false); ic.setAllowPanning(false); ic.setAllowZoom(false); ic.addMouseListener(this); ic.addMouseMotionListener(this); base.add(ic); final JPanel controls = new JPanel(); controls.setPreferredSize(new Dimension(width, 50)); controls.setMaximumSize(new Dimension(width, 50)); controls.setSize(new Dimension(width, 50)); clearBtn = new JButton("Clear"); clearBtn.setActionCommand("button.clear"); clearBtn.addActionListener(this); controls.add(clearBtn); controls.add(new JSeparator(SwingConstants.VERTICAL)); controls.add(new JLabel("Distance:")); distCombo = new JComboBox<String>(); distCombo.addItem("Euclidean"); distCombo.addItem("Manhatten"); distCombo.addItem("Cosine Distance"); controls.add(distCombo); controls.add(new JSeparator(SwingConstants.VERTICAL)); controls.add(new JLabel("Linkage:")); linkCombo = new JComboBox<String>(); for (final Linkage s : Linkage.values()) linkCombo.addItem(s.name()); controls.add(linkCombo); controls.add(new JSeparator(SwingConstants.VERTICAL)); runBtn = new JButton("Run HAC"); runBtn.setActionCommand("button.run"); runBtn.addActionListener(this); controls.add(runBtn); controls.add(new JSeparator(SwingConstants.VERTICAL)); cnclBtn = new JButton("Cancel"); cnclBtn.setEnabled(false); cnclBtn.setActionCommand("button.cancel"); cnclBtn.addActionListener(this); controls.add(cnclBtn); base.add(controls); updateImage(); return base; } @Override public void mouseClicked(MouseEvent e) { if (!isRunning && points.size() <= 25) { final Point pt = e.getPoint(); final Point2dImpl pti = new Point2dImpl(pt.x, pt.y); drawPoint(pti, points.size()); points.add(pti); updateImage(); } } private void drawPoint(final Point2dImpl pti, int index) { renderer.drawShapeFilled(new Circle(pti, 20), RGBColour.MAGENTA); final char c = (char) (65 + index); final GeneralFontStyle<Float[]> style = new GeneralFontStyle<Float[]>(new GeneralFont("Arial", Font.BOLD), renderer, false); style.setHorizontalAlignment(HorizontalAlignment.HORIZONTAL_CENTER); renderer.drawText(c + "", (int) (pti.x), (int) (pti.y + 10), style); } private void resetImage() { image.fill(RGBColour.WHITE); points.clear(); } @Override public void mouseDragged(MouseEvent e) { // ignore } private void updateImage() { ic.setImage(bimg = ImageUtilities.createBufferedImageForDisplay(image, bimg)); } private void initHAC() { if (this.distCombo.getSelectedItem().equals("Euclidean")) this.distanceMeasure = FloatFVComparison.EUCLIDEAN; else if (this.distCombo.getSelectedItem().equals("Manhatten")) this.distanceMeasure = FloatFVComparison.CITY_BLOCK; else if (this.distCombo.getSelectedItem().equals("Cosine Distance")) this.distanceMeasure = FloatFVComparison.COSINE_DIST; final String link = (String) this.linkCombo.getSelectedItem(); linkage = Linkage.valueOf(link); clusters.clear(); image.fill(RGBColour.WHITE); int i = 0; for (final Point2d p : points) { final BiCluster c = new BiCluster(); c.vec = toFloatArray(p, new float[2]); c.label = (char) (65 + i); clusters.add(c); drawPoint((Point2dImpl) p, i++); } updateImage(); } /** * Merge the two closest items */ private void mergeStep() { final int[] lowestpair = { 0, 1 }; double closest = Double.MAX_VALUE; for (int i = 0; i < clusters.size(); i++) { for (int j = i + 1; j < clusters.size(); j++) { final double d = linkage.computeDistance(clusters.get(i), clusters.get(j), distanceMeasure); if (d < closest) { closest = d; lowestpair[0] = i; lowestpair[1] = j; } } } final float[] mergevec = linkage.computeVec(clusters.get(lowestpair[0]), clusters.get(lowestpair[1])); final BiCluster newcluster = new BiCluster(); newcluster.vec = mergevec; newcluster.left = clusters.get(lowestpair[0]); newcluster.right = clusters.get(lowestpair[1]); newcluster.distance = closest; final int x = Math.min(computeLeftBound(newcluster.left), computeLeftBound(newcluster.right)) - 30; final int y = Math.min(computeTopBound(newcluster.left), computeTopBound(newcluster.right)) - 30; final int x1 = Math.max(computeRightBound(newcluster.left), computeRightBound(newcluster.right)) + 30; final int y1 = Math.max(computeBottomBound(newcluster.left), computeBottomBound(newcluster.right)) + 30; final Rectangle r = new Rectangle(x, y, x1 - x, y1 - y); renderer.drawShape(r, 3, RGBColour.RED); if (linkage == Linkage.UPGMC || linkage == Linkage.WPGMC) { image.drawLine((int) newcluster.left.vec[0], (int) newcluster.left.vec[1], (int) newcluster.right.vec[0], (int) newcluster.right.vec[1], 1, RGBColour.BLUE); renderer.drawPoint(new Point2dImpl(newcluster.vec[0], newcluster.vec[1]), RGBColour.GREEN, 5); } updateImage(); clusters.remove(lowestpair[1]); clusters.remove(lowestpair[0]); clusters.add(newcluster); } int computeLeftBound(BiCluster a) { if (a.left == null) return (int) a.vec[0]; else return Math.min(computeLeftBound(a.left), computeLeftBound(a.right)) - 5; } int computeRightBound(BiCluster a) { if (a.left == null) return (int) a.vec[0]; else return Math.max(computeRightBound(a.left), computeRightBound(a.right)) + 5; } int computeTopBound(BiCluster a) { if (a.left == null) return (int) a.vec[1]; else return Math.min(computeTopBound(a.left), computeTopBound(a.right)) - 5; } int computeBottomBound(BiCluster a) { if (a.left == null) return (int) a.vec[1]; else return Math.max(computeBottomBound(a.left), computeBottomBound(a.right)) + 5; } private float[] toFloatArray(Point2d pt, float[] arr) { arr[0] = pt.getX(); arr[1] = pt.getY(); return arr; } @Override public void close() { isRunning = false; } @Override public void actionPerformed(ActionEvent e) { if (e.getActionCommand().equals("button.clear")) { resetImage(); updateImage(); } else if (e.getActionCommand().equals("button.run")) { runBtn.setEnabled(false); clearBtn.setEnabled(false); cnclBtn.setEnabled(true); isRunning = true; new Thread(new Runnable() { @Override public void run() { if (isRunning) { initHAC(); try { Thread.sleep(500); } catch (final InterruptedException e) { e.printStackTrace(); } } while (clusters.size() > 1) { if (isRunning) { mergeStep(); try { Thread.sleep(500); } catch (final InterruptedException e) { e.printStackTrace(); } } else break; } drawDendrogram(); updateImage(); runBtn.setEnabled(true); clearBtn.setEnabled(true); cnclBtn.setEnabled(false); isRunning = false; } }).start(); } else if (e.getActionCommand().equals("button.cancel")) { isRunning = false; cnclBtn.setEnabled(false); } } private int getHeight(BiCluster clust) { // Is this an endpoint? Then the height is just 1 if (clust.left == null && clust.right == null) return 1; // Otherwise the height is the same of the heights of each branch return getHeight(clust.left) + getHeight(clust.right); } private float getDepth(BiCluster clust) { // The distance of an endpoint is 0.0 if (clust.left == null && clust.right == null) return 0; // The distance of a branch is the greater of its two sides plus its own // distance return (float) (Math.max(getDepth(clust.left), getDepth(clust.right)) + clust.distance); } void drawDendrogram() { final int w = 300; final BiCluster root = this.clusters.get(0); // height and width final int h = getHeight(root) * 20; final float depth = getDepth(root); // width is fixed, so scale distances accordingly final float scaling = (w - 150.0F) / depth; final int x0 = renderer.getImage().getWidth() - w; renderer.drawLine(x0, 50 + (h / 2), x0 + 10, 50 + (h / 2), 3, RGBColour.RED); drawnode(root, x0 + 10, 50 + (h / 2), scaling); } void drawnode(BiCluster clust, int x, int y, float scaling) { if (clust.left != null) { final float h1 = getHeight(clust.left) * 20; final float h2 = getHeight(clust.right) * 20; final float top = y - (h1 + h2) / 2; final float bottom = y + (h1 + h2) / 2; // Line length final int ll = (int) (clust.distance * scaling); // Vertical line from this cluster to children renderer.drawLine(x, (int) (top + h1 / 2), x, (int) (bottom - h2 / 2), 3, RGBColour.RED); // Horizontal line to left item renderer.drawLine(x, (int) (top + h1 / 2), x + ll, (int) (top + h1 / 2), 3, RGBColour.RED); // Horizontal line to right item renderer.drawLine(x, (int) (bottom - h2 / 2), x + ll, (int) (bottom - h2 / 2), 3, RGBColour.RED); // Call the function to draw the left and right nodes drawnode(clust.left, x + ll, (int) (top + h1 / 2), scaling); drawnode(clust.right, x + ll, (int) (bottom - h2 / 2), scaling); } else { // If this is an endpoint, draw the item label final GeneralFontStyle<Float[]> style = new GeneralFontStyle<Float[]>(new GeneralFont("Arial", Font.PLAIN), renderer, false); style.setColour(RGBColour.RED); style.setVerticalAlignment(VerticalAlignment.VERTICAL_HALF); renderer.drawText(clust.label + "", x + 5, y + 7, style); } } public static void main(String[] args) throws IOException { new SlideshowApplication(new HClusterDemo(), 1024, 768, Utils.BACKGROUND_IMAGE); } }