package org.neuroph.samples.convolution.util; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.Image; import java.awt.image.BufferedImage; import java.util.ArrayList; import java.util.List; import javax.swing.ImageIcon; import javax.swing.JFrame; import javax.swing.JLabel; import org.neuroph.nnet.comp.layer.FeatureMapLayer; import org.neuroph.nnet.comp.Kernel; import org.neuroph.core.Connection; import org.neuroph.core.Neuron; import org.neuroph.nnet.comp.neuron.BiasNeuron; public class WeightVisualiser { private static final int RATIO = 20; private List<List<Double>> featureDetector; private Kernel kernel; public WeightVisualiser(FeatureMapLayer map, Kernel kernel) { this.kernel = kernel; this.featureDetector = new ArrayList<>(); initWeights(map); } private void initWeights(FeatureMapLayer map) { List<Double> weights = new ArrayList<>(); Neuron neuron = map.getNeuronAt(0); int counter = 0; for (Connection conn : neuron.getInputConnections()) { if (!(conn.getFromNeuron() instanceof BiasNeuron)) { if (counter < kernel.getArea() ) { weights.add(conn.getWeight().getValue()); counter++; } else { featureDetector.add(weights); weights = new ArrayList<>(); weights.add(conn.getWeight().getValue()); counter = 1; } } } featureDetector.add(weights); } public void displayWeights() { for (List<Double> currentKernel : featureDetector) { displayWeight(currentKernel); } } private void displayWeight(List<Double> currentKernel) { JFrame frame = new JFrame("Weight Visualiser: "); frame.setSize(400, 400); JLabel label = new JLabel(); Dimension d = new Dimension(kernel.getWidth() * RATIO, kernel.getHeight() * RATIO); label.setSize(d); label.setPreferredSize(d); frame.getContentPane().add(label, BorderLayout.CENTER); frame.pack(); frame.setVisible(true); BufferedImage image = new BufferedImage(kernel.getWidth(), kernel.getHeight(), BufferedImage.TYPE_BYTE_GRAY); int[] rgb = convertWeightToRGB(currentKernel); image.setRGB(0, 0, kernel.getWidth(), kernel.getHeight(), rgb, 0, kernel.getWidth()); label.setIcon(new ImageIcon(image.getScaledInstance(kernel.getWidth() * RATIO, kernel.getHeight() * RATIO, Image.SCALE_SMOOTH))); } private int[] convertWeightToRGB(List<Double> weights) { normalizeWeights(weights); int[] data = new int[kernel.getWidth() * kernel.getHeight()]; int i = 0; for (Double weight : weights) { int val = (int) (weight * 255); data[i++] = new Color(val, val, val).getRGB(); } return data; } private void normalizeWeights(List<Double> weights) { double min = Double.MAX_VALUE; double max = Double.MIN_VALUE; for (Double weight : weights) { min = Math.min(min, weight); max = Math.max(max, weight); } for (int i = 0; i < weights.size(); i++) { double value = (weights.get(i) - min) / (max - min); weights.set(i, value); } } }