/******************************************************************************* * Copyright 2015, 2016 Taylor G Smith * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package com.clust4j.algo.preprocess; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import com.clust4j.except.ModelNotFitException; import com.clust4j.utils.MatUtils; import com.clust4j.utils.VecUtils; public class RobustScaler extends Transformer { private static final long serialVersionUID = 9139185680482876266L; volatile private MedianCenterer centerer; volatile double[] scale; private RobustScaler(RobustScaler rs) { this.centerer = null == rs.centerer ? null : rs.centerer.copy(); this.scale = VecUtils.copy(rs.scale); } public RobustScaler() { } @Override protected void checkFit() { if(null == centerer) throw new ModelNotFitException("model not yet fit"); } @Override public RealMatrix inverseTransform(RealMatrix X) { checkFit(); // This effectively copies, so no need to do a copy later double[][] data = X.getData(); final int m = data.length; final int n = data[0].length; if(n != this.centerer.medians.length) throw new DimensionMismatchException(n, this.centerer.medians.length); // First, multiply back by scales for(int j = 0; j < n; j++) { for(int i = 0; i < m; i++) { data[i][j] *= scale[j]; // To avoid a second pass of O(M*N), we // won't call the inverseTransform in the centerer, // we will just explicitly add the median back here. data[i][j] += centerer.medians[j]; } } return new Array2DRowRealMatrix(data, false); } @Override public RobustScaler copy() { return new RobustScaler(this); } @Override public RobustScaler fit(RealMatrix X) { synchronized(fitLock) { this.centerer = new MedianCenterer().fit(X); // Get percentile final int n = X.getColumnDimension(); double[][] transpose = X.transpose().getData(); // top row will be 25th, bottom 75 double[][] quantiles_25_75 = new double[2][n]; // Quantile engine DescriptiveStatistics stats; for(int j = 0; j < n; j++) { stats = new DescriptiveStatistics(); for(int i = 0; i < transpose[j].length; i++) { stats.addValue(transpose[j][i]); } quantiles_25_75[0][j] = stats.getPercentile(25); quantiles_25_75[0][j] = stats.getPercentile(75); } // set the scale this.scale = VecUtils.subtract(quantiles_25_75[1], quantiles_25_75[0]); // If we have a constant value, we might get zeroes in the scale: for(int i = 0; i < scale.length; i++) { if(scale[i] == 0) { scale[i] = 1.0; } } return this; } } @Override public RealMatrix transform(RealMatrix data) { return new Array2DRowRealMatrix(transform(data.getData()), false); } @Override public double[][] transform(double[][] data) { checkFit(); MatUtils.checkDimsForUniformity(data); final int m = data.length; final int n = data[0].length; // Dim mismatch will happen on the median side double[][] centered = centerer.transform(data); // Scale: for(int j = 0; j < n; j++) { for(int i = 0; i < m; i++) { centered[i][j] /= scale[j]; } } return centered; } }