/* * Copyright 2014 Jonas Luethke */ package de.tuhh.luethke.okde.model; import java.util.ArrayList; import org.ejml.simple.SimpleMatrix; import org.ejml.simple.SimpleSVD; import de.tuhh.luethke.okde.Exceptions.TooManyComponentsException; import de.tuhh.luethke.okde.utility.Matrices.MatrixOps; public class OneComponentDistribution extends BaseSampleDistribution { public OneComponentDistribution(double w, SimpleMatrix mean, SimpleMatrix covariance, SimpleMatrix bandwidth) { super(); mGlobalWeight = w; mGlobalMean = mean; mGlobalCovariance = covariance; mBandwidthMatrix = bandwidth; mForgettingFactor = 1; } public OneComponentDistribution(OneComponentDistribution oneComponentDistribution) { this.mBandwidthMatrix = oneComponentDistribution.getBandwidthMatrix(); this.mGlobalCovariance = oneComponentDistribution.getGlobalCovariance(); this.mGlobalMean = oneComponentDistribution.getGlobalMean(); this.mSubspaceGlobalCovariance = oneComponentDistribution.getSubspaceGlobalCovariance(); this.mSubspaceInverseCovariance = oneComponentDistribution.getSubspaceInverseCovariance(); this.mGlobalWeight = oneComponentDistribution.getGlobalWeight(); } public OneComponentDistribution(TwoComponentDistribution twoComponentDistribution) { this.mBandwidthMatrix = twoComponentDistribution.getBandwidthMatrix(); this.mGlobalCovariance = twoComponentDistribution.getGlobalCovariance(); this.mGlobalMean = twoComponentDistribution.getGlobalMean(); this.mSubspaceGlobalCovariance = twoComponentDistribution.getSubspaceGlobalCovariance(); this.mSubspaceInverseCovariance = twoComponentDistribution.getSubspaceInverseCovariance(); this.mGlobalWeight = twoComponentDistribution.getGlobalWeight(); } /** * Splits a single component distribution into two components as described in the oKDE-paper. * @return a TwoComponentDistribution */ public TwoComponentDistribution split(double parentWeight){ SimpleSVD<?> svd = mGlobalCovariance.svd(true); SimpleMatrix S = svd.getW(); SimpleMatrix V = svd.getV(); SimpleMatrix d = S.extractDiag(); double max = MatrixOps.maxVectorElement(d); int maxIndex = MatrixOps.maxVectorElementIndex(d); int len = mGlobalCovariance.numRows(); SimpleMatrix M = new SimpleMatrix(len,1); M.set(maxIndex, 0, 1.0d); SimpleMatrix dMean = V.mult(M).scale(0.5*Math.sqrt(max)); SimpleMatrix meanSplit1 = mGlobalMean.plus(dMean); SimpleMatrix meanSplit2 = mGlobalMean.minus(dMean); SimpleMatrix dyadMean = mGlobalMean.mult(mGlobalMean.transpose()); SimpleMatrix dyadMeanSplit1 = meanSplit1.mult(meanSplit1.transpose()); SimpleMatrix dyadMeanSplit2 = meanSplit2.mult(meanSplit2.transpose()); SimpleMatrix covSplit = mGlobalCovariance.plus(dyadMean).minus(dyadMeanSplit1.plus(dyadMeanSplit2).scale(0.5)); SimpleMatrix[] means = {meanSplit1, meanSplit2}; SimpleMatrix[] covariances = {covSplit, covSplit}; double[] weights = {0.5, 0.5}; TwoComponentDistribution splitDist = null; try { splitDist = new TwoComponentDistribution(weights, means, covariances, mBandwidthMatrix); splitDist.setGlobalWeight(parentWeight*mGlobalWeight); splitDist.setGlobalCovariance(mGlobalCovariance); splitDist.setGlobalMean(mGlobalMean); } catch (TooManyComponentsException e) { // cant be thrown } return splitDist; } /** * @see de.tuhh.luethke.okde.model.BaseSampleDistribution#evaluate(SimpleMatrix pointVector) */ @Override public double evaluate(SimpleMatrix pointVector) { SimpleMatrix smoothedCov = mGlobalCovariance.plus(mBandwidthMatrix); double d = 0d; double n = mGlobalMean.numRows(); double a = Math.pow(Math.sqrt(2 * Math.PI), n); double tmp = (-0.5d) * pointVector.minus(mGlobalMean).transpose().mult(smoothedCov.invert()).mult(pointVector.minus(mGlobalMean)).trace(); d += ((1 / (a * Math.sqrt(smoothedCov.determinant()))) * Math.exp(tmp)) * mGlobalWeight; return d; } /** * @see de.tuhh.luethke.oKDE.model.BaseSampleDistribution#evaluate(ArrayList<SimpleMatrix> points) */ @Override public ArrayList<Double> evaluate(ArrayList<SimpleMatrix> points) { ArrayList<Double> resultPoints = new ArrayList<Double>(); for (SimpleMatrix point : points) { resultPoints.add(evaluate(point)); } return resultPoints; } /** * @see de.tuhh.luethke.okde.model.BaseSampleDistribution#setBandwidthMatrix(SimpleMatrix mBandwidthMatrix) */ @Override public void setBandwidthMatrix(SimpleMatrix mBandwidthMatrix) { this.mBandwidthMatrix = mBandwidthMatrix; } }