package org.surus.math; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularValueDecomposition; public class RidgeRegression { private RealMatrix X; private SingularValueDecomposition X_svd = null; private double[] Y; private double l2penalty; private double[] coefficients; private double[] standarderrors; private double[] fitted; private double[] residuals; public RidgeRegression(double[][] x, double[] y) { this.X = MatrixUtils.createRealMatrix(x); this.X_svd = null; this.Y = y; this.l2penalty = 0; this.coefficients = null; this.fitted = new double[y.length]; this.residuals = new double[y.length]; } public void updateCoefficients(double l2penalty) { if (this.X_svd == null) { this.X_svd = new SingularValueDecomposition(X); } RealMatrix V = this.X_svd.getV(); double[] s = this.X_svd.getSingularValues(); RealMatrix U = this.X_svd.getU(); for (int i = 0; i < s.length; i++) { s[i] = s[i] / (s[i]*s[i] + l2penalty); } RealMatrix S = MatrixUtils.createRealDiagonalMatrix(s); RealMatrix Z = V.multiply(S).multiply(U.transpose()); this.coefficients = Z.operate(this.Y); this.fitted = this.X.operate(this.coefficients); double errorVariance = 0; for (int i = 0; i < residuals.length; i++) { this.residuals[i] = this.Y[i] - this.fitted[i]; errorVariance += this.residuals[i] * this.residuals[i]; } errorVariance = errorVariance / (X.getRowDimension() - X.getColumnDimension()); RealMatrix errorVarianceMatrix = MatrixUtils.createRealIdentityMatrix(this.Y.length).scalarMultiply(errorVariance); RealMatrix coefficientsCovarianceMatrix = Z.multiply(errorVarianceMatrix).multiply(Z.transpose()); this.standarderrors = getDiagonal(coefficientsCovarianceMatrix); } private double[] getDiagonal(RealMatrix X) { double[] diag = new double[X.getColumnDimension()]; for (int i = 0; i < diag.length; i++) { diag[i] = X.getEntry(i, i); } return diag; } public double getL2penalty() { return l2penalty; } public void setL2penalty(double l2penalty) { this.l2penalty = l2penalty; } public double[] getCoefficients() { return coefficients; } public double[] getStandarderrors() { return standarderrors; } }